diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 4d2995145eca78921731630f4942cf63dc8a3fef..9385c5e3b008371b3593c8c8ac162f38a8ed18aa 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -65,23 +65,20 @@ permissions:
 jobs:
   tests:
     name: Test
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-24.04
     steps:
-      - name: Free Disk Space (Ubuntu)
-        uses: jlumbroso/free-disk-space@main
-
-      - name: Free up more runner space
-        run: |
-          set +o pipefail
-          # large docker images
-          sudo docker image prune --all --force || true
-          # large packages
-          sudo apt-get purge -y '^llvm-.*' 'php.*' '^mongodb-.*' '^mysql-.*' azure-cli google-cloud-cli google-chrome-stable firefox powershell microsoft-edge-stable || true
-          sudo apt-get autoremove -y
-          sudo apt-get clean
-          # large folders
-          sudo rm -rf /var/lib/apt/lists/* /usr/local/games /usr/local/sqlpackage /usr/local/.ghcup /usr/local/share/powershell /usr/local/share/edge_driver /usr/local/share/gecko_driver /usr/local/share/chromium /usr/local/share/chromedriver-linux64 /usr/local/share/vcpkg /usr/local/lib/python* /usr/local/lib/node_modules /usr/local/julia* /opt/mssql-tools /etc/skel /usr/share/vim /usr/share/postgresql /usr/share/man /usr/share/apache-maven-* /usr/share/R /usr/share/alsa /usr/share/miniconda /usr/share/grub /usr/share/gradle-* /usr/share/locale /usr/share/texinfo /usr/share/kotlinc /usr/share/swift /usr/share/doc /usr/share/az_9.3.0 /usr/share/sbt /usr/share/ri /usr/share/icons /usr/share/java /usr/share/fonts /usr/lib/google-cloud-sdk /usr/lib/jvm /usr/lib/mono /usr/lib/R /usr/lib/postgresql /usr/lib/heroku /usr/lib/gcc
-          set -o pipefail
+      - name: Install liburing
+        run: |
+          sudo apt install liburing-dev -y
+
+      - name: Free up a bit of runner space
+        run: |
+          set +o pipefail
+          sudo docker image prune --all --force || true
+          sudo apt purge -y 'php.*' '^mongodb-.*' '^mysql-.*' azure-cli google-cloud-cli google-chrome-stable firefox powershell microsoft-edge-stable || true
+          sudo apt clean
+          sudo rm -v -rf /usr/local/games /usr/local/sqlpackage /usr/local/share/powershell /usr/local/share/edge_driver /usr/local/share/gecko_driver /usr/local/share/chromium /usr/local/share/chromedriver-linux64 /usr/lib/google-cloud-sdk /usr/lib/jvm /usr/lib/mono /usr/lib/heroku
+          set -o pipefail

       - name: Sync repository
         uses: actions/checkout@v4

@@ -231,7 +228,7 @@ jobs:
   build:
     name: Build
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-24.04
     needs: tests
     strategy:
       matrix:
@@ -239,13 +236,10 @@
         - target: aarch64-linux-musl
         - target: x86_64-linux-musl
     steps:
-      - name: Free Disk Space (Ubuntu)
-        uses: jlumbroso/free-disk-space@main
-
       - name: Sync repository
         uses: actions/checkout@v4

-      - uses: nixbuild/nix-quick-install-action@v28
+      - uses: nixbuild/nix-quick-install-action@master

       - name: Restore and cache Nix store
         uses: nix-community/cache-nix-action@v5.1.0
@@ -450,6 +444,7 @@ jobs:
     steps:
       - name: Sync repository
         uses: actions/checkout@v4
+
       - name: Tag comparison check
         if: ${{ startsWith(github.ref, 'refs/tags/v') && !endsWith(github.ref, '-rc') }}
         run: |
           echo '# WARNING: Attempting to run this workflow for a tag that is not the latest repo tag. Aborting.' >> $GITHUB_STEP_SUMMARY
           exit 1
           fi

+      # use sccache for Rust
       - name: Run sccache-cache
         if: (github.event.pull_request.draft != true) && (vars.DOCKER_USERNAME != '') && (vars.GITLAB_USERNAME != '') && (vars.SCCACHE_ENDPOINT != '') && (github.event.pull_request.user.login != 'renovate[bot]')
         uses: mozilla-actions/sccache-action@main

+      # use rust-cache
       - uses: Swatinem/rust-cache@v2
         with:
           cache-all-crates: "true"

+      # Nix can't do portable macOS builds yet
       - name: Build macOS x86_64 binary
         if: ${{ matrix.os == 'macos-13' }}
@@ -475,22 +473,26 @@
           CONDUWUIT_VERSION_EXTRA="$(git rev-parse --short HEAD)" cargo build --release
           cp -v -f target/release/conduit conduwuit-macos-x86_64
           otool -L conduwuit-macos-x86_64
+      # quick smoke test of the x86_64 macOS binary
       - name: Run x86_64 macOS release binary
         if: ${{ matrix.os == 'macos-13' }}
         run: |
           ./conduwuit-macos-x86_64 --version
+
       - name: Build macOS arm64 binary
         if: ${{ matrix.os == 'macos-latest' }}
         run: |
           CONDUWUIT_VERSION_EXTRA="$(git rev-parse --short HEAD)" cargo build --release
           cp -v -f target/release/conduit conduwuit-macos-arm64
           otool -L conduwuit-macos-arm64
+      # quick smoke test of the arm64 macOS binary
       - name: Run arm64 macOS release binary
         if: ${{ matrix.os == 'macos-latest' }}
         run: |
           ./conduwuit-macos-arm64 --version
+
       - name: Upload macOS x86_64 binary
         if: ${{ matrix.os == 'macos-13' }}
         uses: actions/upload-artifact@v4
@@ -498,6 +500,7 @@
         with:
           name: conduwuit-macos-x86_64
           path: conduwuit-macos-x86_64
           if-no-files-found: error
+
       - name: Upload macOS arm64 binary
         if: ${{ matrix.os == 'macos-latest' }}
         uses: actions/upload-artifact@v4
@@ -508,7 +511,7 @@

   docker:
     name: Docker publish
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-24.04
     needs: build
     if: (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main' || (github.event.pull_request.draft != true)) && (vars.DOCKER_USERNAME != '') && (vars.GITLAB_USERNAME != '') && github.event.pull_request.user.login != 'renovate[bot]'
     env:
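The `Tag comparison check` step shown above only includes the failure branch in the visible hunk context; the comparison that guards it sits outside the diff. For orientation, a check of this general shape (a hypothetical sketch, not the workflow's verbatim script) would compare the pushed tag against the newest tag in the repository:

    # requires tags to be fetched, e.g. actions/checkout with fetch-depth: 0
    LATEST_TAG="$(git tag --list --sort=-v:refname | head -n1)"
    if [ "${GITHUB_REF_NAME}" != "${LATEST_TAG}" ]; then
      echo '# WARNING: Attempting to run this workflow for a tag that is not the latest repo tag. Aborting.' >> "$GITHUB_STEP_SUMMARY"
      exit 1
    fi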
diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml
index 506a87d95d1ec7a607a4e9b4df010c529c2a647a..17b1f9c17cfce29990080780005cc184980cb398 100644
--- a/.github/workflows/documentation.yml
+++ b/.github/workflows/documentation.yml
@@ -39,7 +39,7 @@ concurrency:
 jobs:
   docs:
     name: Documentation and GitHub Pages
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-24.04

     permissions:
       pages: write
@@ -50,14 +50,20 @@
       url: ${{ steps.deployment.outputs.page_url }}

     steps:
-      - name: Free Disk Space (Ubuntu)
-        uses: jlumbroso/free-disk-space@main
+      - name: Free up a bit of runner space
+        run: |
+          set +o pipefail
+          sudo docker image prune --all --force || true
+          sudo apt purge -y 'php.*' '^mongodb-.*' '^mysql-.*' azure-cli google-cloud-cli google-chrome-stable firefox powershell microsoft-edge-stable || true
+          sudo apt clean
+          sudo rm -v -rf /usr/local/games /usr/local/sqlpackage /usr/local/share/powershell /usr/local/share/edge_driver /usr/local/share/gecko_driver /usr/local/share/chromium /usr/local/share/chromedriver-linux64 /usr/lib/google-cloud-sdk /usr/lib/jvm /usr/lib/mono /usr/lib/heroku
+          set -o pipefail

       - name: Sync repository
         uses: actions/checkout@v4

       - name: Setup GitHub Pages
-        if: github.event_name != 'pull_request'
+        if: (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') && (github.event_name != 'pull_request') && (github.event.pull_request.user.login == 'girlbossceo')
         uses: actions/configure-pages@v5

       - uses: nixbuild/nix-quick-install-action@master
@@ -139,12 +145,12 @@
           compression-level: 0

       - name: Upload generated documentation (book) as GitHub Pages artifact
-        if: github.event_name != 'pull_request'
+        if: (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') && (github.event_name != 'pull_request') && (github.event.pull_request.user.login == 'girlbossceo')
         uses: actions/upload-pages-artifact@v3
         with:
           path: public

       - name: Deploy to GitHub Pages
-        if: github.event_name != 'pull_request'
+        if: (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') && (github.event_name != 'pull_request') && (github.event.pull_request.user.login == 'girlbossceo')
         id: deployment
         uses: actions/deploy-pages@v4

diff --git a/.github/workflows/trivy.yml b/.github/workflows/trivy.yml
deleted file mode 100644
index 1f0dd7df28107c89584791f6d242b85d5525dcae..0000000000000000000000000000000000000000
--- a/.github/workflows/trivy.yml
+++ /dev/null
@@ -1,42 +0,0 @@
-name: Trivy code and vulnerability scanning
-
-on:
-  pull_request:
-  push:
-    branches:
-      - main
-    tags:
-      - '*'
-  schedule:
-    - cron: '00 12 * * *'
-
-permissions:
-  contents: read
-
-jobs:
-  trivy-scan:
-    name: Trivy Scan
-    runs-on: ubuntu-latest
-    permissions:
-      contents: read
-      security-events: write
-      actions: read
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-
-      - name: Run Trivy code and vulnerability scanner on repo
-        uses: aquasecurity/trivy-action@0.28.0
-        with:
-          scan-type: repo
-          format: sarif
-          output: trivy-results.sarif
-          severity: CRITICAL,HIGH,MEDIUM,LOW
-
-      - name: Run Trivy code and vulnerability scanner on filesystem
-        uses: aquasecurity/trivy-action@0.28.0
-        with:
-          scan-type: fs
-          format: sarif
-          output: trivy-results.sarif
-          severity: CRITICAL,HIGH,MEDIUM,LOW

diff --git a/Cargo.lock b/Cargo.lock index 6386f96857f64113a2fdd771673bef5eceaad7aa..3a95f83a508c8cf7d3cf1098afda4c57d18b48bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -43,15 +43,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anyhow" -version = "1.0.91" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" [[package]] name = "arc-swap" @@ -76,6 +76,9 @@ name = "arrayvec" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +dependencies = [ + "serde", +] [[package]] name = "as_variant" @@ -124,7 +127,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -135,7 +138,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -161,9 +164,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-lc-rs" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdd82dba44d209fddb11c190e0a94b78651f95299598e472215667417a03ff1d" +checksum = 
"fe7c2840b66236045acd2607d5866e274380afd87ef99d6226e961e2cb47df45" dependencies = [ "aws-lc-sys", "mirai-annotations", @@ -173,9 +176,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.22.0" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df7a4168111d7eb622a31b214057b8509c0a7e1794f44c546d742330dc793972" +checksum = "ad3a619a9de81e1d7de1f1186dcba4506ed661a0e483d84410fdef0ee87b2f96" dependencies = [ "bindgen", "cc", @@ -188,9 +191,9 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.7" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +checksum = "49c41b948da08fb481a94546cd874843adc1142278b0af4badf9b1b78599d68d" dependencies = [ "async-trait", "axum-core", @@ -254,9 +257,9 @@ dependencies = [ [[package]] name = "axum-extra" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73c3220b188aea709cf1b6c5f9b01c3bd936bb08bd2b5184a12b35ac8131b1f9" +checksum = "37634d71e9f3c35cfb1c30c87c7cba500d55892f04c2dbe6a99383c664b820b0" dependencies = [ "axum", "axum-core", @@ -272,7 +275,6 @@ dependencies = [ "tower 0.5.1", "tower-layer", "tower-service", - "tracing", ] [[package]] @@ -290,7 +292,7 @@ dependencies = [ "hyper", "hyper-util", "pin-project-lite", - "rustls 0.23.15", + "rustls 0.23.16", "rustls-pemfile", "rustls-pki-types", "tokio", @@ -310,7 +312,7 @@ dependencies = [ "http", "http-body-util", "pin-project", - "rustls 0.23.15", + "rustls 0.23.16", "tokio", "tokio-rustls", "tokio-util", @@ -370,7 +372,7 @@ dependencies = [ "regex", "rustc-hash 1.1.0", "shlex", - "syn 2.0.85", + "syn 2.0.87", "which", ] @@ -455,6 +457,12 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +[[package]] +name = "bytesize" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e368af43e418a04d52505cf3dbc23dda4e3407ae2fa99fd0e4f308ce546acc" + [[package]] name = "bzip2-sys" version = "0.1.11+1.0.8" @@ -478,9 +486,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.31" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "jobserver", "libc", @@ -539,9 +547,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", "clap_derive", @@ -549,9 +557,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstyle", "clap_lex", @@ -566,14 +574,14 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "cmake" @@ -592,7 +600,7 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] name = "conduit" -version = "0.4.7" +version = "0.5.0" dependencies = [ "clap", "conduit_admin", @@ -621,15 +629,16 @@ dependencies = [ [[package]] name = "conduit_admin" -version = "0.4.7" +version = "0.5.0" dependencies = [ "clap", "conduit_api", "conduit_core", + "conduit_database", "conduit_macros", "conduit_service", "const-str", - "futures-util", + "futures", "log", "ruma", "serde_json", @@ -641,7 +650,7 @@ dependencies = [ [[package]] name = "conduit_api" -version = "0.4.7" +version = "0.5.0" dependencies = [ "axum", "axum-client-ip", @@ -652,7 +661,7 @@ dependencies = [ "conduit_database", "conduit_service", "const-str", - "futures-util", + "futures", "hmac", "http", "http-body-util", @@ -666,19 +675,20 @@ dependencies = [ "serde", "serde_html_form", "serde_json", - "sha-1", + "sha1", "tokio", "tracing", ] [[package]] name = "conduit_core" -version = "0.4.7" +version = "0.5.0" dependencies = [ "argon2", "arrayvec", "axum", "bytes", + "bytesize", "cargo_toml", "checked_ops", "chrono", @@ -689,6 +699,7 @@ dependencies = [ "cyborgtime", "either", "figment", + "futures", "hardened_malloc-rs", "http", "http-body-util", @@ -707,7 +718,8 @@ dependencies = [ "serde", "serde_json", "serde_regex", - "thiserror", + "serde_yaml", + "thiserror 2.0.3", "tikv-jemalloc-ctl", "tikv-jemalloc-sys", "tikv-jemallocator", @@ -722,29 +734,33 @@ dependencies = [ [[package]] name = "conduit_database" -version = "0.4.7" +version = "0.5.0" dependencies = [ + "arrayvec", "conduit_core", "const-str", + "futures", "log", "rust-rocksdb-uwu", + "serde", + "serde_json", "tokio", "tracing", ] [[package]] name = "conduit_macros" -version = "0.4.7" +version = "0.5.0" dependencies = [ "itertools 0.13.0", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] name = "conduit_router" -version = "0.4.7" +version = "0.5.0" dependencies = [ "axum", "axum-client-ip", @@ -756,13 +772,14 @@ dependencies = [ "conduit_core", "conduit_service", "const-str", + "futures", "http", "http-body-util", "hyper", "hyper-util", "log", "ruma", - "rustls 0.23.15", + "rustls 0.23.16", "sd-notify", "sentry", "sentry-tower", @@ -776,15 +793,16 @@ dependencies = [ [[package]] name = "conduit_service" -version = "0.4.7" +version = "0.5.0" dependencies = [ + "arrayvec", "async-trait", "base64 0.22.1", "bytes", "conduit_core", "conduit_database", "const-str", - "futures-util", + "futures", "hickory-resolver", "http", "image", @@ -894,9 +912,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" dependencies = [ "libc", ] @@ -1035,7 +1053,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" dependencies = [ "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1062,7 +1080,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ 
"proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1123,6 +1141,17 @@ dependencies = [ "subtle", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "dunce" version = "1.0.5" @@ -1172,7 +1201,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1193,9 +1222,9 @@ dependencies = [ [[package]] name = "fdeflate" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab" +checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb" dependencies = [ "simd-adler32", ] @@ -1234,9 +1263,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.34" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" dependencies = [ "crc32fast", "miniz_oxide", @@ -1264,7 +1293,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" dependencies = [ "nonempty", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -1283,6 +1312,20 @@ dependencies = [ "new_debug_unreachable", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -1324,7 +1367,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1345,6 +1388,7 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -1434,9 +1478,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" [[package]] name = "hdrhistogram" @@ -1510,7 +1554,7 @@ dependencies = [ "ipnet", "once_cell", "rand", - "thiserror", + "thiserror 1.0.69", "tinyvec", "tokio", "tracing", @@ -1533,7 +1577,7 @@ dependencies = [ "rand", "resolv-conf", "smallvec", - "thiserror", + "thiserror 1.0.69", "tokio", "tracing", ] @@ -1589,7 +1633,7 @@ dependencies = [ "markup5ever", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1684,7 +1728,7 @@ dependencies = [ "http", "hyper", "hyper-util", - "rustls 0.23.15", + "rustls 0.23.16", "rustls-native-certs", "rustls-pki-types", "tokio", @@ -1726,6 +1770,124 @@ dependencies = [ "tracing", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "idna" version = "0.4.0" @@ -1738,19 +1900,30 @@ dependencies = [ [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" 
+dependencies = [ + "icu_normalizer", + "icu_properties", ] [[package]] name = "image" -version = "0.25.4" +version = "0.25.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc144d44a31d753b02ce64093d532f55ff8dc4ebf2ffb8a63c0dda691385acae" +checksum = "cd6f44aed642f18953a158afeb30206f4d50da59fbc66ecb53c66488de73563b" dependencies = [ "bytemuck", "byteorder-lite", @@ -1790,7 +1963,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.15.0", + "hashbrown 0.15.1", "serde", ] @@ -1906,11 +2079,9 @@ checksum = "b9ae10193d25051e74945f1ea2d0b42e03cc3b890f7e4cc5faa44997d808193f" dependencies = [ "base64 0.21.7", "js-sys", - "pem", "ring", "serde", "serde_json", - "simple_asn1", ] [[package]] @@ -1953,7 +2124,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1970,9 +2141,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.161" +version = "0.2.162" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" +checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" [[package]] name = "libloading" @@ -2007,6 +2178,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" + [[package]] name = "lock_api" version = "0.4.12" @@ -2025,9 +2202,13 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "loole" -version = "0.3.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad95468e4700cb37d8d1f198050db18cebe55e4b4c8aa9180a715deedb2f8965" +checksum = "a2998397c725c822c6b2ba605fd9eb4c6a7a0810f1629ba3cc232ef4f0308d96" +dependencies = [ + "futures-core", + "futures-sink", +] [[package]] name = "lru-cache" @@ -2329,7 +2510,7 @@ dependencies = [ "js-sys", "once_cell", "pin-project-lite", - "thiserror", + "thiserror 1.0.69", "urlencoding", ] @@ -2372,10 +2553,10 @@ dependencies = [ "glob", "once_cell", "opentelemetry", - "ordered-float 4.4.0", + "ordered-float 4.5.0", "percent-encoding", "rand", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-stream", ] @@ -2391,9 +2572,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "4.4.0" +version = "4.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83e7ccb95e240b7c9506a3d544f10d935e142cc90b0a1d56954fb44d89ad6b97" +checksum = "c65ee1f9701bf938026630b455d5315f490640234259037edb259798b3bcf85e" dependencies = [ "num-traits", ] @@ -2475,17 +2656,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.85", -] - -[[package]] -name = "pem" -version = "3.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae" -dependencies = [ - "base64 0.22.1", - "serde", + "syn 2.0.87", ] [[package]] @@ -2568,7 +2739,7 @@ checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", 
"quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2640,7 +2811,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2669,7 +2840,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", "version_check", "yansi", ] @@ -2694,7 +2865,7 @@ dependencies = [ "itertools 0.13.0", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2738,45 +2909,49 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quinn" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ "bytes", "pin-project-lite", "quinn-proto", "quinn-udp", "rustc-hash 2.0.0", - "rustls 0.23.15", + "rustls 0.23.16", "socket2", - "thiserror", + "thiserror 2.0.3", "tokio", "tracing", ] [[package]] name = "quinn-proto" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", + "getrandom", "rand", "ring", "rustc-hash 2.0.0", - "rustls 0.23.15", + "rustls 0.23.16", + "rustls-pki-types", "slab", - "thiserror", + "thiserror 2.0.3", "tinyvec", "tracing", + "web-time 1.1.0", ] [[package]] name = "quinn-udp" -version = "0.5.5" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" +checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" dependencies = [ + "cfg_aliases", "libc", "once_cell", "socket2", @@ -2840,7 +3015,7 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.8", + "regex-automata 0.4.9", "regex-syntax 0.8.5", ] @@ -2855,9 +3030,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -2878,9 +3053,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.8" +version = "0.12.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "async-compression", "base64 0.22.1", @@ -2904,7 +3079,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.15", + "rustls 0.23.16", "rustls-native-certs", "rustls-pemfile", "rustls-pki-types", @@ -2953,7 +3128,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = 
"git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "assign", "js_int", @@ -2975,7 +3150,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "js_int", "ruma-common", @@ -2987,7 +3162,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "as_variant", "assign", @@ -3002,7 +3177,7 @@ dependencies = [ "serde", "serde_html_form", "serde_json", - "thiserror", + "thiserror 2.0.3", "url", "web-time 1.1.0", ] @@ -3010,7 +3185,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "as_variant", "base64 0.22.1", @@ -3028,7 +3203,7 @@ dependencies = [ "serde", "serde_html_form", "serde_json", - "thiserror", + "thiserror 2.0.3", "time", "tracing", "url", @@ -3040,7 +3215,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "as_variant", "indexmap 2.6.0", @@ -3054,7 +3229,7 @@ dependencies = [ "ruma-macros", "serde", "serde_json", - "thiserror", + "thiserror 2.0.3", "tracing", "url", "web-time 1.1.0", @@ -3064,7 +3239,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "bytes", "http", @@ -3082,16 +3257,16 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "js_int", - "thiserror", + "thiserror 2.0.3", ] [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = 
"git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "js_int", "ruma-common", @@ -3101,7 +3276,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "cfg-if", "once_cell", @@ -3110,14 +3285,14 @@ dependencies = [ "quote", "ruma-identifiers-validation", "serde", - "syn 2.0.85", + "syn 2.0.87", "toml", ] [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "js_int", "ruma-common", @@ -3129,20 +3304,20 @@ dependencies = [ [[package]] name = "ruma-server-util" version = "0.3.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "headers", "http", "http-auth", "ruma-common", - "thiserror", + "thiserror 2.0.3", "tracing", ] [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -3152,28 +3327,29 @@ dependencies = [ "serde_json", "sha2", "subslice", - "thiserror", + "thiserror 2.0.3", ] [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?rev=9900d0676564883cfade556d6e8da2a2c9061efd#9900d0676564883cfade556d6e8da2a2c9061efd" +source = "git+https://github.com/girlbossceo/ruwuma?rev=2ab432fba19eb8862c594d24af39d8f9f6b4eac6#2ab432fba19eb8862c594d24af39d8f9f6b4eac6" dependencies = [ - "itertools 0.12.1", + "futures-util", + "itertools 0.13.0", "js_int", "ruma-common", "ruma-events", "serde", "serde_json", - "thiserror", + "thiserror 2.0.3", "tracing", ] [[package]] name = "rust-librocksdb-sys" -version = "0.28.0+9.7.3" -source = "git+https://github.com/girlbossceo/rust-rocksdb-zaidoon1?rev=c1e5523eae095a893deaf9056128c7dbc2d5fd73#c1e5523eae095a893deaf9056128c7dbc2d5fd73" +version = "0.29.0+9.7.4" +source = "git+https://github.com/girlbossceo/rust-rocksdb-zaidoon1?rev=2bc5495a9f8f75073390c326b47ee5928ab7c7f0#2bc5495a9f8f75073390c326b47ee5928ab7c7f0" dependencies = [ "bindgen", "bzip2-sys", @@ -3189,8 +3365,8 @@ dependencies = [ [[package]] name = "rust-rocksdb" -version = "0.31.0" -source = "git+https://github.com/girlbossceo/rust-rocksdb-zaidoon1?rev=c1e5523eae095a893deaf9056128c7dbc2d5fd73#c1e5523eae095a893deaf9056128c7dbc2d5fd73" +version = "0.33.0" +source = "git+https://github.com/girlbossceo/rust-rocksdb-zaidoon1?rev=2bc5495a9f8f75073390c326b47ee5928ab7c7f0#2bc5495a9f8f75073390c326b47ee5928ab7c7f0" dependencies = [ "libc", "rust-librocksdb-sys", @@ -3233,9 +3409,9 @@ 
dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" dependencies = [ "bitflags 2.6.0", "errno", @@ -3260,9 +3436,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.15" +version = "0.23.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" +checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" dependencies = [ "aws-lc-rs", "log", @@ -3301,6 +3477,9 @@ name = "rustls-pki-types" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +dependencies = [ + "web-time 1.1.0", +] [[package]] name = "rustls-webpki" @@ -3330,7 +3509,7 @@ dependencies = [ "futures-util", "pin-project", "thingbuf", - "thiserror", + "thiserror 1.0.69", "unicode-segmentation", "unicode-width", ] @@ -3343,11 +3522,10 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "sanitize-filename" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ed72fbaf78e6f2d41744923916966c4fbe3d7c74e3037a8ee482f1115572603" +checksum = "bc984f4f9ceb736a7bb755c3e3bd17dc56370af2600c9780dcc48c66453da34d" dependencies = [ - "lazy_static", "regex", ] @@ -3387,9 +3565,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" +checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2" dependencies = [ "core-foundation-sys", "libc", @@ -3530,7 +3708,7 @@ dependencies = [ "rand", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", "time", "url", "uuid", @@ -3538,22 +3716,22 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.213" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.213" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -3635,17 +3813,6 @@ dependencies = [ "unsafe-libyaml", ] -[[package]] -name = "sha-1" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "sha1" version = "0.10.6" @@ -3728,18 +3895,6 @@ version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" -[[package]] -name = "simple_asn1" -version = "0.6.2" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" -dependencies = [ - "num-bigint", - "num-traits", - "thiserror", - "time", -] - [[package]] name = "siphasher" version = "0.3.11" @@ -3787,6 +3942,12 @@ dependencies = [ "der", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "strict" version = "0.2.0" @@ -3847,9 +4008,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.85" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -3871,6 +4032,17 @@ dependencies = [ "futures-core", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "tendril" version = "0.4.3" @@ -3884,9 +4056,9 @@ dependencies = [ [[package]] name = "termimad" -version = "0.30.1" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22117210909e9dfff30a558f554c7fb3edb198ef614e7691386785fb7679677c" +checksum = "9cda3a7471f9978706978454c45ef8dda67e9f8f3cdb9319eb2e9323deb6ae62" dependencies = [ "coolor", "crokey", @@ -3894,7 +4066,7 @@ dependencies = [ "lazy-regex", "minimad", "serde", - "thiserror", + "thiserror 1.0.69", "unicode-width", ] @@ -3910,22 +4082,42 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.65" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +dependencies = [ + "thiserror-impl 2.0.3", ] [[package]] name = "thiserror-impl" -version = "1.0.65" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", ] [[package]] @@ -4019,6 +4211,16 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -4036,9 +4238,9 @@ checksum = 
"1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.41.0" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", @@ -4060,7 +4262,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -4081,7 +4283,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.15", + "rustls 0.23.16", "rustls-pki-types", "tokio", ] @@ -4094,7 +4296,7 @@ checksum = "0d4770b8024672c1101b3f6733eab95b18007dbe0847a8afe341fcf79e06043f" dependencies = [ "either", "futures-util", - "thiserror", + "thiserror 1.0.69", "tokio", ] @@ -4274,7 +4476,7 @@ source = "git+https://github.com/girlbossceo/tracing?rev=4d78a14a5e03f539b8c6b47 dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -4447,7 +4649,7 @@ dependencies = [ "base64 0.22.1", "log", "once_cell", - "rustls 0.23.15", + "rustls 0.23.16", "rustls-pki-types", "url", "webpki-roots", @@ -4455,12 +4657,12 @@ dependencies = [ [[package]] name = "url" -version = "2.5.2" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" dependencies = [ "form_urlencoded", - "idna 0.5.0", + "idna 1.0.3", "percent-encoding", "serde", ] @@ -4477,6 +4679,18 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "uuid" version = "1.11.0" @@ -4542,7 +4756,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", "wasm-bindgen-shared", ] @@ -4576,7 +4790,7 @@ checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4906,6 +5120,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "xml5ever" version = "0.18.1" @@ -4923,6 +5149,30 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yoke" +version = "0.7.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -4941,7 +5191,28 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", +] + +[[package]] +name = "zerofrom" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "synstructure", ] [[package]] @@ -4950,6 +5221,28 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "zstd" version = "0.13.2" diff --git a/Cargo.toml b/Cargo.toml index b75c4975713e310fe970e3503b26eae2ce06a1ef..68c87c572fbc393d08e1826959f43e46d5dd52b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,13 +20,14 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/girlbossceo/conduwuit" rust-version = "1.82.0" -version = "0.4.7" +version = "0.5.0" [workspace.metadata.crane] name = "conduit" [workspace.dependencies.arrayvec] version = "0.7.4" +features = ["std", "serde"] [workspace.dependencies.const-str] version = "0.5.7" @@ -45,20 +46,20 @@ default-features = false features = ["parse"] [workspace.dependencies.sanitize-filename] -version = "0.5.0" +version = "0.6.0" [workspace.dependencies.jsonwebtoken] version = "9.3.0" +default-features = false [workspace.dependencies.base64] version = "0.22.1" +default-features = false # used for TURN server authentication [workspace.dependencies.hmac] version = "0.12.1" - -[workspace.dependencies.sha-1] -version = "0.10.1" +default-features = false # used for checking if an IP is in specific subnets / CIDR ranges easier [workspace.dependencies.ipaddress] @@ -69,16 +70,16 @@ version = "0.8.5" # Used for the http request / response body type for Ruma endpoints used with reqwest [workspace.dependencies.bytes] -version = "1.7.2" +version = "1.8.0" [workspace.dependencies.http-body-util] -version = "0.1.1" +version = "0.1.2" [workspace.dependencies.http] version = "1.1.0" [workspace.dependencies.regex] -version = "1.10.6" +version = "1.11.1" 
 [workspace.dependencies.axum]
 version = "0.7.5"
@@ -94,7 +95,7 @@ features = [
 ]

 [workspace.dependencies.axum-extra]
-version = "0.9.3"
+version = "0.9.4"
 default-features = false
 features = ["typed-header", "tracing"]

@@ -115,7 +116,7 @@ default-features = false
 features = ["util"]

 [workspace.dependencies.tower-http]
-version = "0.6.0"
+version = "0.6.1"
 default-features = false
 features = [
     "add-extension",
@@ -128,10 +129,12 @@ features = [
 ]

 [workspace.dependencies.rustls]
-version = "0.23.13"
+version = "0.23.16"
+default-features = false
+features = ["aws_lc_rs"]

 [workspace.dependencies.reqwest]
-version = "0.12.8"
+version = "0.12.9"
 default-features = false
 features = [
     "rustls-tls-native-roots",
@@ -141,12 +144,12 @@ features = [
 ]

 [workspace.dependencies.serde]
-version = "1.0.209"
+version = "1.0.215"
 default-features = false
 features = ["rc"]

 [workspace.dependencies.serde_json]
-version = "1.0.124"
+version = "1.0.132"
 default-features = false
 features = ["raw_value"]

@@ -170,7 +173,7 @@ default-features = false

 # Used to generate thumbnails for images
 [workspace.dependencies.image]
-version = "0.25.1"
+version = "0.25.5"
 default-features = false
 features = [
     "jpeg",
@@ -188,9 +191,11 @@ version = "0.1.40"
 default-features = false

 [workspace.dependencies.tracing-subscriber]
 version = "0.3.18"
-features = ["env-filter"]
+default-features = false
+features = ["env-filter", "std", "tracing", "tracing-log", "ansi", "fmt"]

 [workspace.dependencies.tracing-core]
 version = "0.1.32"
+default-features = false

 # for URL previews
 [workspace.dependencies.webpage]
@@ -199,23 +204,26 @@ default-features = false

 # used for conduit's CLI and admin room command parsing
 [workspace.dependencies.clap]
-version = "4.5.20"
+version = "4.5.21"
 default-features = false
 features = [
     "std",
     "derive",
     "help",
+    #"color", Do we need these?
+ #"unicode", "usage", "error-context", "string", ] -[workspace.dependencies.futures-util] +[workspace.dependencies.futures] version = "0.3.30" default-features = false +features = ["std", "async-await"] [workspace.dependencies.tokio] -version = "1.40.0" +version = "1.41.1" default-features = false features = [ "fs", @@ -236,7 +244,7 @@ version = "0.8.5" # Validating urls in config, was already a transitive dependency [workspace.dependencies.url] -version = "2.5.0" +version = "2.5.3" default-features = false features = ["serde"] @@ -256,26 +264,24 @@ features = [ ] [workspace.dependencies.hyper-util] -# 0.1.9 causes DNS issues +# hyper-util >=0.1.9 seems to have DNS issues version = "=0.1.8" default-features = false features = [ - "client", "server-auto", "server-graceful", - "service", "tokio", ] # to support multiple variations of setting a config option [workspace.dependencies.either] -version = "1.11.0" +version = "1.13.0" default-features = false features = ["serde"] # Used for reading the configuration from conduwuit.toml & environment variables [workspace.dependencies.figment] -version = "0.10.18" +version = "0.10.19" default-features = false features = ["env", "toml"] @@ -285,11 +291,13 @@ default-features = false # Used for conduit::Error type [workspace.dependencies.thiserror] -version = "1.0.63" +version = "2.0.3" +default-features = false # Used when hashing the state [workspace.dependencies.ring] version = "0.17.8" +default-features = false # Used to make working with iterators easier, was already a transitive depdendency [workspace.dependencies.itertools] @@ -302,10 +310,10 @@ version = "2.1.1" # used to replace the channels of the tokio runtime [workspace.dependencies.loole] -version = "0.3.1" +version = "0.4.0" [workspace.dependencies.async-trait] -version = "0.1.81" +version = "0.1.83" [workspace.dependencies.lru-cache] version = "0.1.2" @@ -314,7 +322,7 @@ version = "0.1.2" [workspace.dependencies.ruma] git = "https://github.com/girlbossceo/ruwuma" #branch = "conduwuit-changes" -rev = "9900d0676564883cfade556d6e8da2a2c9061efd" +rev = "2ab432fba19eb8862c594d24af39d8f9f6b4eac6" features = [ "compat", "rand", @@ -345,6 +353,7 @@ features = [ "unstable-msc4121", "unstable-msc4125", "unstable-msc4186", + "unstable-msc4210", # remove legacy mentions "unstable-extensible-events", ] @@ -360,9 +369,13 @@ features = [ "bzip2", ] -# optional SHA256 media keys feature [workspace.dependencies.sha2] version = "0.10.8" +default-features = false + +[workspace.dependencies.sha1] +version = "0.10.6" +default-features = false # optional opentelemetry, performance measurements, flamegraphs, etc for performance measurements and monitoring [workspace.dependencies.opentelemetry] @@ -430,7 +443,8 @@ default-features = false features = ["resource"] [workspace.dependencies.sd-notify] -version = "0.4.1" +version = "0.4.3" +default-features = false [workspace.dependencies.hardened_malloc-rs] version = "0.1.2" @@ -446,23 +460,25 @@ version = "0.4.3" default-features = false [workspace.dependencies.termimad] -version = "0.30.1" +version = "0.31.0" default-features = false [workspace.dependencies.checked_ops] version = "0.1" [workspace.dependencies.syn] -version = "2.0.76" +version = "2.0.87" default-features = false features = ["full", "extra-traits"] [workspace.dependencies.quote] -version = "1.0.36" +version = "1.0.37" [workspace.dependencies.proc-macro2] version = "1.0.89" +[workspace.dependencies.bytesize] +version = "1.3.0" # # Patches @@ -771,6 +787,7 @@ unused-qualifications = "warn" 
#unused-results = "warn" # TODO ## some sadness +elided_named_lifetimes = "allow" # TODO! let_underscore_drop = "allow" missing_docs = "allow" # cfgs cannot be limited to expected cfgs or their de facto non-transitive/opt-in use-case e.g. @@ -828,6 +845,7 @@ missing_panics_doc = { level = "allow", priority = 1 } module_name_repetitions = { level = "allow", priority = 1 } no_effect_underscore_binding = { level = "allow", priority = 1 } similar_names = { level = "allow", priority = 1 } +single_match_else = { level = "allow", priority = 1 } struct_field_names = { level = "allow", priority = 1 } unnecessary_wraps = { level = "allow", priority = 1 } unused_async = { level = "allow", priority = 1 } diff --git a/README.md b/README.md index 962139d6487a197bc484b098b8720acad9451c89..4e97f1f00a730510744f1b9c855e98c9d1e8435a 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ ### a very cool, featureful fork of [Conduit](https://conduit.rs/) <!-- ANCHOR_END: catchphrase --> Visit the [conduwuit documentation](https://conduwuit.puppyirl.gay/) for more -information. +information and instructions on deploying and setting up conduwuit. <!-- ANCHOR: body --> @@ -63,7 +63,9 @@ #### Can I migrate or switch from Conduit? conduwuit is a complete drop-in replacement for Conduit. As long as you are using RocksDB, the only "migration" you need to do is replace the binary or container image. There -is no harm or additional steps required for using conduwuit. +is no harm or additional steps required for using conduwuit. See the +[Migrating from Conduit](https://conduwuit.puppyirl.gay/deploying/generic.html#migrating-from-conduit) section +in the generic deployment guide. <!-- ANCHOR_END: body --> diff --git a/bin/complement b/bin/complement index 601edb5a7725fc5894c75b524783a652b543cb63..a1db4b32562713f6c3880346eabb08e55b3ef117 100755 --- a/bin/complement +++ b/bin/complement @@ -18,7 +18,7 @@ RESULTS_FILE="$3" OCI_IMAGE="complement-conduwuit:main" # Complement tests that are skipped due to flakiness/reliability issues -SKIPPED_COMPLEMENT_TESTS='-skip=TestClientSpacesSummary.*|TestJoinFederatedRoomFromApplicationServiceBridgeUser.*|TestJumpToDateEndpoint.*' +SKIPPED_COMPLEMENT_TESTS='-skip=TestClientSpacesSummary.*|TestJoinFederatedRoomFromApplicationServiceBridgeUser.*|TestJumpToDateEndpoint.*|TestUnbanViaInvite.*' # $COMPLEMENT_SRC needs to be a directory to Complement source code if [ -f "$COMPLEMENT_SRC" ]; then diff --git a/clippy.toml b/clippy.toml index c942b93c7b87838a20c239902996f2378683cb7c..b93b2377505f52ab2833c601013b6bb789c4bfce 100644 --- a/clippy.toml +++ b/clippy.toml @@ -2,6 +2,14 @@ array-size-threshold = 4096 cognitive-complexity-threshold = 94 # TODO reduce me ALARA excessive-nesting-threshold = 11 # TODO reduce me to 4 or 5 future-size-threshold = 7745 # TODO reduce me ALARA -stack-size-threshold = 144000 # reduce me ALARA +stack-size-threshold = 196608 # reduce me ALARA too-many-lines-threshold = 700 # TODO reduce me to <= 100 type-complexity-threshold = 250 # reduce me to ~200 + +disallowed-macros = [ + { path = "log::error", reason = "use conduit_core::error" }, + { path = "log::warn", reason = "use conduit_core::warn" }, + { path = "log::info", reason = "use conduit_core::info" }, + { path = "log::debug", reason = "use conduit_core::debug" }, + { path = "log::trace", reason = "use conduit_core::trace" }, +] diff --git a/conduwuit-example.toml b/conduwuit-example.toml index b532d381f630e771eb42a80b70375b7492371085..2f3da71f4e5f225f8a4591ee56303bf9ac13d025 100644 --- a/conduwuit-example.toml +++
b/conduwuit-example.toml @@ -1,935 +1,1382 @@ -# ============================================================================= -# This is the official example config for conduwuit. -# If you use it for your server, you will need to adjust it to your own needs. -# At the very least, change the server_name field! -# -# This documentation can also be found at https://conduwuit.puppyirl.gay/configuration.html -# ============================================================================= +### conduwuit Configuration +### +### THIS FILE IS GENERATED. CHANGES/CONTRIBUTIONS IN THE REPO WILL +### BE OVERWRITTEN! +### +### You should rename this file before configuring your server. Changes +### to documentation and defaults can be contributed in source code at +### src/core/config/mod.rs. This file is generated when building. +### +### Any values pre-populated are the default values for said config option. +### +### At a minimum, you MUST edit all the config options for your environment +### that say "YOU NEED TO EDIT THIS". +### See https://conduwuit.puppyirl.gay/configuration.html for ways to +### configure conduwuit [global] -# The server_name is the pretty name of this server. It is used as a suffix for user -and room ids. Examples: matrix.org, conduit.rs +# The server_name is the pretty name of this server. It is used as a +# suffix for user and room IDs/aliases. +# +# See the docs for reverse proxying and delegation: https://conduwuit.puppyirl.gay/deploying/generic.html#setting-up-the-reverse-proxy +# Also see the `[global.well_known]` config section at the very bottom. +# +# Examples of delegation: +# - https://puppygock.gay/.well-known/matrix/server +# - https://puppygock.gay/.well-known/matrix/client +# +# YOU NEED TO EDIT THIS. THIS CANNOT BE CHANGED AFTERWARDS WITHOUT A +# DATABASE WIPE. +# +# example: "conduwuit.woof" +# +#server_name = -# The Conduit server needs all /_matrix/ requests to be reachable at -# https://your.server.name/ on port 443 (client-server) and 8448 (federation). +# default address (IPv4 or IPv6) conduwuit will listen on. +# +# If you are using Docker or a container NAT networking setup, this must +# be "0.0.0.0". +# +# To listen on multiple addresses, specify a vector e.g. ["127.0.0.1", +# "::1"] +# +#address = ["127.0.0.1", "::1"] -# If that's not possible for you, you can create /.well-known files to redirect -requests (delegation). See -# https://spec.matrix.org/latest/client-server-api/#getwell-knownmatrixclient -# and -# https://spec.matrix.org/v1.9/server-server-api/#getwell-knownmatrixserver -# for more information +# The port(s) conduwuit will be running on. +# +# See https://conduwuit.puppyirl.gay/deploying/generic.html#setting-up-the-reverse-proxy for reverse proxying. +# +# Docker users: Don't change this, you'll need to map an external port to +# this. +# +# To listen on multiple ports, specify a vector e.g. [8080, 8448] +# +#port = 8008 -# YOU NEED TO EDIT THIS -#server_name = "your.server.name" +# Uncomment unix_socket_path to listen on a UNIX socket at the specified +# path. If listening on a UNIX socket, you MUST remove/comment the +# 'address' key if defined, AND add your reverse proxy to the 'conduwuit' +# group, unless world RW permissions are specified with unix_socket_perms +# (666 minimum). +# +# example: "/run/conduwuit/conduwuit.sock" +# +#unix_socket_path = -# Servers listed here will be used to gather public keys of other servers (notary trusted key servers). +# The default permissions (in octal) to create the UNIX socket with.
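#
# As a sketch of how the two socket options combine (re-using only the
# example values above; the path is illustrative, not a default): 660
# grants read/write to the conduwuit user and its group, so adding your
# reverse proxy's user to the 'conduwuit' group is enough, while 666 would
# also grant world read/write:
#
#   unix_socket_path = "/run/conduwuit/conduwuit.sock"
#   unix_socket_perms = 660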
# -# The default behaviour for conduwuit is to attempt to query trusted key servers before querying the individual servers. -# This is done for performance reasons, but if you would like to query individual servers before the notary servers -configured below, set to +#unix_socket_perms = 660 + +# This is the only directory where conduwuit will save its data, including +# media. +# Note: this was previously "/var/lib/matrix-conduit" +# +# YOU NEED TO EDIT THIS. # -# (Currently, conduwuit doesn't support batched key requests, so this list should only contain Synapse servers) -# Defaults to `matrix.org` -# trusted_servers = ["matrix.org"] +# example: "/var/lib/conduwuit" +# +#database_path = -# Sentry.io crash/panic reporting, performance monitoring/metrics, etc. This is NOT enabled by default. -# conduwuit's default Sentry reporting endpoint is o4506996327251968.ingest.us.sentry.io +# conduwuit supports online database backups using RocksDB's Backup engine +# API. To use this, set a database backup path that conduwuit can write +# to. # -# Defaults to *false* -#sentry = false +# See https://conduwuit.puppyirl.gay/maintenance.html#backups for more information. +# +# example: "/opt/conduwuit-db-backups" +# +#database_backup_path = -# Sentry reporting URL if a custom one is desired +# The amount of online RocksDB database backups to keep/retain, if using +# "database_backup_path", before deleting the oldest one. # -# Defaults to conduwuit's default Sentry endpoint: "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536" -#sentry_endpoint = "" +#database_backups_to_keep = 1 -# Report your Conduwuit server_name in Sentry.io crash reports and metrics +# Set this to any float value in megabytes for conduwuit to tell the +# database engine that this much memory is available for database-related +# caches. # -# Defaults to false -#sentry_send_server_name = false +# May be useful if you have significant memory to spare to increase +# performance. +# +# Similar to the individual LRU caches, this is scaled up with your CPU +# core count. +# +# This defaults to 128.0 + (64.0 * CPU core count) +# +#db_cache_capacity_mb = -# Performance monitoring/tracing sample rate for Sentry.io +# Option to control adding arbitrary text to the end of the user's +# displayname upon registration with a space before the text. This was the +# lightning bolt emoji option, just replaced with support for adding your +# own custom text or emojis. To disable, set this to "" (an empty string). # -# Note that too high values may impact performance, and can be disabled by setting it to 0.0 (0%) -# This value is read as a percentage to Sentry, represented as a decimal +# The default is the trans pride flag. # -# Defaults to 15% of traces (0.15) -#sentry_traces_sample_rate = 0.15 +# example: "🏳️‍⚧️" +# +#new_user_displayname_suffix = "🏳️‍⚧️" -# Whether to attach a stacktrace to Sentry reports. -#sentry_attach_stacktrace = false +# If enabled, conduwuit will send a simple GET request periodically to +# `https://pupbrain.dev/check-for-updates/stable` for any new +# announcements made. Despite the name, this is not an update check +# endpoint, it is simply an announcement check endpoint. +# +# This is disabled by default as this is rarely used except for security +# updates or major updates. +# +#allow_check_for_updates = false -# Send panics to sentry. This is true by default, but sentry has to be enabled.
-#sentry_send_panic = true +# Set this to any float value to multiply conduwuit's in-memory LRU caches +# with, such as "auth_chain_cache_capacity". +# +# May be useful if you have significant memory to spare to increase +# performance. This was previously called +# `conduit_cache_capacity_modifier`. +# +# If you have low memory, reducing this may be viable. +# +# By default, the individual caches such as "auth_chain_cache_capacity" +# are scaled by your CPU core count. +# +#cache_capacity_modifier = 1.0 -# Send errors to sentry. This is true by default, but sentry has to be enabled. This option is -only effective in release-mode; forced to false in debug-mode. -#sentry_send_error = true +# This item is undocumented. Please contribute documentation for it. +# +#pdu_cache_capacity = varies by system -# Controls the tracing log level for Sentry to send things like breadcrumbs and transactions -Defaults to "info" -#sentry_filter = "info" +# This item is undocumented. Please contribute documentation for it. +# +#auth_chain_cache_capacity = varies by system +# This item is undocumented. Please contribute documentation for it. +# +#shorteventid_cache_capacity = varies by system -### Database configuration +# This item is undocumented. Please contribute documentation for it. +# +#eventidshort_cache_capacity = varies by system -# This is the only directory where conduwuit will save its data, including media. -# Note: this was previously "/var/lib/matrix-conduit" -database_path = "/var/lib/conduwuit" +# This item is undocumented. Please contribute documentation for it. +# +#shortstatekey_cache_capacity = varies by system -# Database backend: Only rocksdb is supported. -database_backend = "rocksdb" +# This item is undocumented. Please contribute documentation for it. +# +#statekeyshort_cache_capacity = varies by system +# This item is undocumented. Please contribute documentation for it. +# +#server_visibility_cache_capacity = varies by system -### Network +# This item is undocumented. Please contribute documentation for it. +# +#user_visibility_cache_capacity = varies by system -# The port(s) conduwuit will be running on. You need to set up a reverse proxy such as -# Caddy or Nginx so all requests to /_matrix on port 443 and 8448 will be -forwarded to the conduwuit instance running on this port -# Docker users: Don't change this, you'll need to map an external port to this. -# To listen on multiple ports, specify a vector e.g. [8080, 8448] +# This item is undocumented. Please contribute documentation for it. # -# default if unspecified is 8008 -port = 6167 +#stateinfo_cache_capacity = varies by system -# default address (IPv4 or IPv6) conduwuit will listen on. Generally you want this to be -# localhost (127.0.0.1 / ::1). If you are using Docker or a container NAT networking setup, you -likely need this to be 0.0.0.0. -To listen multiple addresses, specify a vector e.g. ["127.0.0.1", "::1"] +# This item is undocumented. Please contribute documentation for it. # -# default if unspecified is both IPv4 and IPv6 localhost: ["127.0.0.1", "::1"] -address = "127.0.0.1" +#roomid_spacehierarchy_cache_capacity = varies by system -# Max request size for file uploads -max_request_size = 20_000_000 # in bytes +# Maximum entries stored in DNS memory-cache. The size of an entry may +# vary so please take care if raising this value excessively. Only +# decrease this when using an external DNS cache.
Please note +# that systemd-resolved does *not* count as an external cache, even when +# configured to do so. +# +#dns_cache_entries = 32768 -# Uncomment unix_socket_path to listen on a UNIX socket at the specified path. -# If listening on a UNIX socket, you must remove/comment the 'address' key if defined and add your -reverse proxy to the 'conduwuit' group, unless world RW permissions are specified with unix_socket_perms (666 minimum). -#unix_socket_path = "/run/conduwuit/conduwuit.sock" -#unix_socket_perms = 660 +# Minimum time-to-live in seconds for entries in the DNS cache. The +# default may appear high to most administrators; this is by design as the +# majority of NXDOMAINs are correct for a long time (e.g. the server is no +# longer running Matrix). Only decrease this if you are using an external +# DNS cache. +# +# default_dns_min_ttl: 259200 +# +#dns_min_ttl = -# Set this to true for conduwuit to compress HTTP response bodies using zstd. -# This option does nothing if conduwuit was not built with `zstd_compression` feature. -# Please be aware that enabling HTTP compression may weaken TLS. -# Most users should not need to enable this. -# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before deciding to enable this. -zstd_compression = false - -# Set this to true for conduwuit to compress HTTP response bodies using gzip. -# This option does nothing if conduwuit was not built with `gzip_compression` feature. -# Please be aware that enabling HTTP compression may weaken TLS. -# Most users should not need to enable this. -# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before deciding to enable this. -gzip_compression = false - -# Set this to true for conduwuit to compress HTTP response bodies using brotli. -# This option does nothing if conduwuit was not built with `brotli_compression` feature. -# Please be aware that enabling HTTP compression may weaken TLS. -# Most users should not need to enable this. -# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before deciding to enable this. -brotli_compression = false - -# Vector list of IPv4 and IPv6 CIDR ranges / subnets *in quotes* that you do not want conduwuit to send outbound requests to. -Defaults to RFC1918, unroutable, loopback, multicast, and testnet addresses for security. +# Minimum time-to-live in seconds for NXDOMAIN entries in the DNS cache. +# This value is critical for the server to federate efficiently. +# NXDOMAINs are assumed not to be returning to the federation and are +# aggressively cached rather than constantly rechecked. # -# To disable, set this to be an empty vector (`[]`). -# Please be aware that this is *not* a guarantee. You should be using a firewall with zones as doing this on the application layer may have bypasses. +# Defaults to 3 days as these are *very rarely* false negatives. # -# Currently this does not account for proxies in use like Synapse does. -ip_range_denylist = [ - "127.0.0.0/8", - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - "100.64.0.0/10", - "192.0.0.0/24", - "169.254.0.0/16", - "192.88.99.0/24", - "198.18.0.0/15", - "192.0.2.0/24", - "198.51.100.0/24", - "203.0.113.0/24", - "224.0.0.0/4", - "::1/128", - "fe80::/10", - "fc00::/7", - "2001:db8::/32", - "ff00::/8", - "fec0::/10", -] - - -### Moderation / Privacy / Security - -# Config option to control whether the legacy unauthenticated Matrix media repository endpoints will be enabled.
-# These endpoints consist of: -# - /_matrix/media/*/config -# - /_matrix/media/*/upload -# - /_matrix/media/*/preview_url -# - /_matrix/media/*/download/* -# - /_matrix/media/*/thumbnail/* +#dns_min_ttl_nxdomain = 259200 + +# Number of retries after a timeout. # -# The authenticated equivalent endpoints are always enabled. +#dns_attempts = 10 + +# The number of seconds to wait for a reply to a DNS query. Please note +# that recursive queries can take up to several seconds for some domains, +# so this value should not be too low, especially on slower hardware or +# resolvers. # -# Defaults to true for now, but this is highly subject to change, likely in the next release. -#allow_legacy_media = true +#dns_timeout = 10 -# Set to true to allow user type "guest" registrations. Element attempts to register guest users automatically. -Defaults to false -allow_guest_registration = false +# Fallback to TCP on DNS errors. Set this to false if unsupported by +# nameserver. +# +#dns_tcp_fallback = true -# Set to true to log guest registrations in the admin room. -Defaults to false as it may be noisy or unnecessary. -log_guest_registrations = false +# Enable to query all nameservers until the domain is found. Referred to +# as "trust_negative_responses" in hickory_resolver. This can avoid +# useless DNS queries if the first nameserver responds with NXDOMAIN or +# an empty NOERROR response. +# +#query_all_nameservers = true -# Set to true to allow guest registrations/users to auto join any rooms specified in `auto_join_rooms` -Defaults to false -allow_guests_auto_join_rooms = false +# Enables using *only* TCP for querying your specified nameservers instead +# of UDP. +# +# If you are running conduwuit in a container environment, this config option may need to be enabled. See https://conduwuit.puppyirl.gay/troubleshooting.html#potential-dns-issues-when-using-docker for more details. +# +#query_over_tcp_only = false -# Vector list of servers that conduwuit will refuse to download remote media from. -# No default. -# prevent_media_downloads_from = ["example.com", "example.local"] +# DNS A/AAAA record lookup strategy +# +# Takes one of the following numeric values: +# 1 - Ipv4Only (Only query for A records, no AAAA/IPv6) +# +# 2 - Ipv6Only (Only query for AAAA records, no A/IPv4) +# +# 3 - Ipv4AndIpv6 (Query for A and AAAA records in parallel, uses whatever +# returns a successful response first) +# +# 4 - Ipv6thenIpv4 (Query for AAAA record, if that fails then query the A +# record) +# +# 5 - Ipv4thenIpv6 (Query for A record, if that fails then query the AAAA +# record) +# +# If you don't have IPv6 networking, it may be suitable to set this to +# Ipv4Only (1) for better DNS performance, as the AAAA record contents +# would never be used even when the AAAA query succeeds. +# +#ip_lookup_strategy = 5 -# Enables registration. If set to false, no users can register on this -server. -# If set to true without a token configured, users can register with no form of 2nd- -step only if you set -`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` to -true in your config. If you would like -registration only via token reg, please configure the `registration_token` key. -allow_registration = false -# Please note that an open registration homeserver with no second-step verification -is highly prone to abuse and potential defederation by homeservers, including -# matrix.org.
- -# A static registration token that new users will have to provide when creating -# an account. If unset and `allow_registration` is true, registration is open -# without any condition. YOU NEED TO EDIT THIS. -registration_token = "change this token for something specific to your server" - -# controls whether federation is allowed or not -# defaults to true -# allow_federation = true - -# controls whether users are allowed to create rooms. -# appservices and admins are always allowed to create rooms -# defaults to true -# allow_room_creation = true - -# controls whether non-admin local users are forbidden from sending room invites (local and remote), -# and if non-admin users can receive remote room invites. admins are always allowed to send and receive all room invites. -# defaults to false -# block_non_admin_invites = false - -# List of forbidden username patterns/strings. Values in this list are matched as *contains*. -# This is checked upon username availability check, registration, and startup as warnings if any local users in your database -# have a forbidden username. -# No default. -# forbidden_usernames = [] - -# List of forbidden room aliases and room IDs as patterns/strings. Values in this list are matched as *contains*. -# This is checked upon room alias creation, custom room ID creation if used, and startup as warnings if any room aliases -# in your database have a forbidden room alias/ID. -# No default. -# forbidden_alias_names = [] - -# List of forbidden server names that we will block incoming AND outgoing federation with, and block client room joins / remote user invites. -# -# This check is applied on the room ID, room alias, sender server name, sender user's server name, inbound federation X-Matrix origin, and outbound federation handler. -# -# Basically "global" ACLs. No default. -# forbidden_remote_server_names = [] - -# List of forbidden server names that we will block all outgoing federated room directory requests for. Useful for preventing our users from wandering into bad servers or spaces. -# No default. -# forbidden_remote_room_directory_server_names = [] - -# Set this to true to allow your server's public room directory to be federated. -# Set this to false to protect against /publicRooms spiders, but will forbid external users -# from viewing your server's public room directory. If federation is disabled entirely -# (`allow_federation`), this is inherently false. -allow_public_room_directory_over_federation = false - -# Set this to true to allow your server's public room directory to be queried without client -# authentication (access token) through the Client APIs. Set this to false to protect against /publicRooms spiders. -allow_public_room_directory_without_auth = false - -# Set this to true to lock down your server's public room directory and only allow admins to publish rooms to the room directory. -# Unpublishing is still allowed by all users with this enabled. -# -# Defaults to false -lockdown_public_room_directory = false - -# Set this to true to allow federating device display names / allow external users to see your device display name. -# If federation is disabled entirely (`allow_federation`), this is inherently false. For privacy, this is best disabled. -allow_device_name_federation = false - -# Vector list of domains allowed to send requests to for URL previews. Defaults to none. -# Note: this is a *contains* match, not an explicit match. 
Putting "google.com" will match "https://google.com" and "http://mymaliciousdomainexamplegoogle.com" -# Setting this to "*" will allow all URL previews. Please note that this opens up significant attack surface to your server, you are expected to be aware of the risks by doing so. -url_preview_domain_contains_allowlist = [] - -# Vector list of explicit domains allowed to send requests to for URL previews. Defaults to none. -# Note: This is an *explicit* match, not a contains match. Putting "google.com" will match "https://google.com", "http://google.com", but not "https://mymaliciousdomainexamplegoogle.com" -# Setting this to "*" will allow all URL previews. Please note that this opens up significant attack surface to your server, you are expected to be aware of the risks by doing so. -url_preview_domain_explicit_allowlist = [] - -# Vector list of URLs allowed to send requests to for URL previews. Defaults to none. -# Note that this is a *contains* match, not an explicit match. Putting "google.com" will match "https://google.com/", "https://google.com/url?q=https://mymaliciousdomainexample.com", and "https://mymaliciousdomainexample.com/hi/google.com" -# Setting this to "*" will allow all URL previews. Please note that this opens up significant attack surface to your server, you are expected to be aware of the risks by doing so. -url_preview_url_contains_allowlist = [] - -# Vector list of explicit domains not allowed to send requests to for URL previews. Defaults to none. -# Note: This is an *explicit* match, not a contains match. Putting "google.com" will match "https://google.com", "http://google.com", but not "https://mymaliciousdomainexamplegoogle.com" -# The denylist is checked first before allowlist. Setting this to "*" will not do anything. -url_preview_domain_explicit_denylist = [] - -# Maximum amount of bytes allowed in a URL preview body size when spidering. Defaults to 384KB (384_000 bytes) -url_preview_max_spider_size = 384_000 - -# Option to decide whether you would like to run the domain allowlist checks (contains and explicit) on the root domain or not. Does not apply to URL contains allowlist. Defaults to false. -# Example: If this is enabled and you have "wikipedia.org" allowed in the explicit and/or contains domain allowlist, it will allow all subdomains under "wikipedia.org" such as "en.m.wikipedia.org" as the root domain is checked and matched. -# Useful if the domain contains allowlist is still too broad for you but you still want to allow all the subdomains under a root domain. -url_preview_check_root_domain = false - -# Config option to allow or disallow incoming federation requests that obtain the profiles -# of our local users from `/_matrix/federation/v1/query/profile` +# Max request size for file uploads in bytes. Defaults to 20MB. # -# This is inherently false if `allow_federation` is disabled +#max_request_size = 20971520 + +# This item is undocumented. Please contribute documentation for it. # -# Defaults to true -allow_profile_lookup_federation_requests = true +#max_fetch_prev_events = 192 -# Config option to automatically deactivate the account of any user who attempts to join a: -# - banned room -# - forbidden room alias -# - room alias or ID with a forbidden server name +# Default/base connection timeout (seconds). This is used only by URL +# previews and update/news endpoint checks. 
# -# This may be useful if all your banned lists consist of toxic rooms or servers that no good faith user would ever attempt to join, and -# to automatically remediate the problem without any admin user intervention. +#request_conn_timeout = 10 + +# Default/base request timeout (seconds). The time waiting to receive more +# data from another server. This is used only by URL previews, +# update/news, and misc endpoint checks. # -# This will also make the user leave all rooms. Federation (e.g. remote room invites) are ignored here. +#request_timeout = 35 + +# Default/base request total timeout (seconds). The time limit for a whole +# request. This is set very high to not cancel healthy requests while +# serving as a backstop. This is used only by URL previews and +# update/news endpoint checks. # -# Defaults to false as rooms can be banned for non-moderation-related reasons -#auto_deactivate_banned_room_attempts = false +#request_total_timeout = 320 +# Default/base idle connection pool timeout (seconds). This is used only +# by URL previews and update/news endpoint checks. +# +#request_idle_timeout = 5 -### Admin Room and Console +# Default/base max idle connections per host. This is used only by URL +# previews and update/news endpoint checks. Defaults to 1 as generally the +# same open connection can be re-used. +# +#request_idle_per_host = 1 -# Controls whether the conduwuit admin room console / CLI will immediately activate on startup. -# This option can also be enabled with `--console` conduwuit argument +# Federation well-known resolution connection timeout (seconds) # -# Defaults to false -#admin_console_automatic = false +#well_known_conn_timeout = 6 -# Controls what admin commands will be executed on startup. This is a vector list of strings of admin commands to run. +# Federation HTTP well-known resolution request timeout (seconds) # -# An example of this can be: `admin_execute = ["debug ping puppygock.gay", "debug echo hi"]` +#well_known_timeout = 10 + +# Federation client request timeout (seconds). You most definitely want +# this to be high to account for extremely large room joins, slow +# homeservers, your own resources etc. # -# This option can also be configured with the `--execute` conduwuit argument and can take standard shell commands and environment variables +#federation_timeout = 300 + +# Federation client idle connection pool timeout (seconds) # -# Such example could be: `./conduwuit --execute "server admin-notice conduwuit has started up at $(date)"` +#federation_idle_timeout = 25 + +# Federation client max idle connections per host. Defaults to 1 as +# generally the same open connection can be re-used # -# Defaults to nothing. -#admin_execute = [""] +#federation_idle_per_host = 1 -# Controls whether conduwuit should error and fail to start if an admin execute command (`--execute` / `admin_execute`) fails +# Federation sender request timeout (seconds). The time it takes for the +# remote server to process sent transactions can take a while. # -# Defaults to false -#admin_execute_errors_ignore = false +#sender_timeout = 180 -# Controls the max log level for admin command log captures (logs generated from running admin commands) +# Federation sender idle connection pool timeout (seconds) # -# Defaults to "info" on release builds, else "debug" on debug builds -#admin_log_capture = info +#sender_idle_timeout = 180 -# Allows admins to enter commands in rooms other than #admins by prefixing with \!admin. 
The reply -will be publicly visible to the room, originating from the sender. -defaults to true -#admin_escape_commands = true +# Federation sender transaction retry backoff limit (seconds) +# +#sender_retry_backoff_limit = 86400 -# Controls whether admin room notices like account registrations, password changes, account deactivations, -room directory publications, etc will be sent to the admin room. +# Appservice URL request connection timeout. Defaults to 35 seconds as +# generally appservices are hosted within the same network. # -# Update notices and normal admin command responses will still be sent. +#appservice_timeout = 35 + +# Appservice URL idle connection pool timeout (seconds) # -# defaults to true -#admin_room_notices = true +#appservice_idle_timeout = 300 +# Notification gateway pusher idle connection pool timeout +# +#pusher_idle_timeout = 15 -### Misc +# Enables registration. If set to false, no users can register on this +# server. +# +# If set to true without a token configured, users can register with no +# form of second-step verification only if you set +# `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` to +# true in your config. +# +# If you would like token-only registration, please configure +# `registration_token` or `registration_token_file`. +# +#allow_registration = false -# max log level for conduwuit. allows debug, info, warn, or error -# see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives -# **Caveat**: -# For release builds, the tracing crate is configured to only implement levels higher than error to avoid unnecessary overhead in the compiled binary from trace macros. -# For debug builds, this restriction is not applied. +# This item is undocumented. Please contribute documentation for it. # -# Defaults to "info" -#log = "info" +#yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse = false -# controls whether logs will be outputted with ANSI colours +# A static registration token that new users will have to provide when +# creating an account. If unset and `allow_registration` is true, +# registration is open without any condition. # -# defaults to true -#log_colors = true +# YOU NEED TO EDIT THIS OR USE registration_token_file. +# +# example: "o&^uCtes4HPf0Vu@F20jQeeWE7" +# +#registration_token = -# controls whether encrypted rooms and events are allowed (default true) -#allow_encryption = false +# Path to a file on the system that gets read for the registration token. +# This config option takes precedence/priority over "registration_token". +# +# conduwuit must be able to access the file, and it must not be empty +# +# example: "/etc/conduwuit/.reg_token" +# +#registration_token_file = -# if enabled, conduwuit will send a simple GET request periodically to `https://pupbrain.dev/check-for-updates/stable` -for any new announcements made. Despite the name, this is not an update check -endpoint, it is simply an announcement check endpoint. -Defaults to false. -#allow_check_for_updates = false +# Controls whether encrypted rooms and events are allowed. +# +#allow_encryption = true -# Set to false to disable users from joining or creating room versions that aren't 100% officially supported by conduwuit. -# conduwuit officially supports room versions 6 - 10. conduwuit has experimental/unstable support for 3 - 5, and 11. -Defaults to true. -#allow_unstable_room_versions = true -# Option to control adding arbitrary text to the end of the user's displayname upon registration with a space before the text. -# This was the lightning bolt emoji option, just replaced with support for adding your own custom text or emojis. -To disable, set this to "" (an empty string) -Defaults to "🏳️‍⚧️" (trans pride flag) -#new_user_displayname_suffix = "🏳️‍⚧️" +# Controls whether federation is allowed or not.
It is not recommended to +# disable this after the fact due to potential federation breakage. +# +#allow_federation = true -# This item is undocumented. Please contribute documentation for it. +# +#federation_loopback = false -# Option to control whether conduwuit will query your list of trusted notary key servers (`trusted_servers`) for -remote homeserver signing keys it doesn't know *first*, or query the individual servers first before falling back to the trusted -key servers. +# Set this to true to require authentication on the normally +# unauthenticated profile retrieval endpoints (GET) +# "/_matrix/client/v3/profile/{userId}". # -# The former/default behaviour makes federated/remote rooms joins generally faster because we're querying a single (or list of) server -that we know works, is reasonably fast, and is reliable for just about all the homeserver signing keys in the room. Querying individual -servers may take longer depending on the general infrastructure of everyone in there, how many dead servers there are, etc. +# This can prevent profile scraping. # -# However, this does create an increased reliance on one single or multiple large entities as `trusted_servers` should generally -contain long-term and large servers who know a very large number of homeservers. +#require_auth_for_profile_requests = false + +# Set this to true to allow your server's public room directory to be +# federated. Set this to false to protect against /publicRooms spiders, +# but this will forbid external users from viewing your server's public +# room directory. If federation is disabled entirely (`allow_federation`), +# this is inherently false. # -# If you don't know what any of this means, leave this and `trusted_servers` alone to their defaults. +#allow_public_room_directory_over_federation = false + +# Set this to true to allow your server's public room directory to be +# queried without client authentication (access token) through the Client +# APIs. Set this to false to protect against /publicRooms spiders. # -# Defaults to true as this is the fastest option for federation. -#query_trusted_key_servers_first = true +#allow_public_room_directory_without_auth = false -# List/vector of room **IDs** that conduwuit will make newly registered users join. -The room IDs specified must be rooms that you have joined at least once on the server, and must be public. +# allow guests/unauthenticated users to access TURN credentials # -# No default. -#auto_join_rooms = [] +# this is the equivalent of Synapse's `turn_allow_guests` config option. +# this allows any unauthenticated user to call the endpoint +# `/_matrix/client/v3/voip/turnServer`. +# +# It is unlikely you need to enable this, as all major clients support +# authentication for this endpoint, which prevents misuse of your TURN +# server by potential bots. +# +#turn_allow_guests = false -# Retry failed and incomplete messages to remote servers immediately upon startup. This is called bursting. -If this is disabled, said messages may not be delivered until more messages are queued for that server.
-Do not change this option unless server resources are extremely limited or the scale of the server's -deployment is huge. Do not disable this unless you know what you are doing. -#startup_netburst = true +# Set this to true to lock down your server's public room directory and +# only allow admins to publish rooms to the room directory. Unpublishing +# is still allowed by all users with this enabled. +# +#lockdown_public_room_directory = false -# Limit the startup netburst to the most recent (default: 50) messages queued for each remote server. All older -messages are dropped and not reattempted. The `startup_netburst` option must be enabled for this value to have -any effect. Do not change this value unless you know what you are doing. Set this value to -1 to reattempt -every message without trimming the queues; this may consume significant disk. Set this value to 0 to drop all -messages without any attempt at redelivery. -#startup_netburst_keep = 50 +# Set this to true to allow federating device display names / allow +# external users to see your device display name. If federation is +# disabled entirely (`allow_federation`), this is inherently false. For +# privacy reasons, this is best left disabled. +# +#allow_device_name_federation = false + +# Config option to allow or disallow incoming federation requests that +# obtain the profiles of our local users from +# `/_matrix/federation/v1/query/profile` +# +# Increases privacy of your local users' data, such as display names, but +# some remote users may get a false "this user does not exist" error when +# they try to invite you to a DM or room. It can also protect against +# profile spiders. +# +# This is inherently false if `allow_federation` is disabled +# +#allow_inbound_profile_lookup_federation_requests = true + +# controls whether standard users are allowed to create rooms. appservices +# and admins are always allowed to create rooms +# +#allow_room_creation = true + +# Set to false to disable users from joining or creating room versions +# that aren't 100% officially supported by conduwuit. +# +# conduwuit officially supports room versions 6 - 11. +# +# conduwuit has slightly experimental (though it works fine in practice) +# support for versions 3 - 5 +# +#allow_unstable_room_versions = true + +# default room version conduwuit will create rooms with. +# +# per spec, room version 10 is the default. +# +#default_room_version = 10 -# If the 'perf_measurements' feature is enabled, enables collecting folded stack trace profile of tracing spans using -# tracing_flame. The resulting profile can be visualized with inferno[1], speedscope[2], or a number of other tools. +# This item is undocumented. Please contribute documentation for it. +# +#allow_jaeger = false + +# This item is undocumented. Please contribute documentation for it. +# +#jaeger_filter = "info" + +# If the 'perf_measurements' compile-time feature is enabled, enables +# collecting folded stack trace profile of tracing spans using +# tracing_flame. The resulting profile can be visualized with inferno[1], +# speedscope[2], or a number of other tools. +# # [1]: https://github.com/jonhoo/inferno # [2]: www.speedscope.app -# tracing_flame = false +# +#tracing_flame = false -# If 'tracing_flame' is enabled, sets a filter for which events will be included in the profile. -Supported syntax is documented at https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives -# tracing_flame_filter = "trace,h2=off" +# This item is undocumented.
Please contribute documentation for it. +# +#tracing_flame_filter = "info" -# If 'tracing_flame' is enabled, set the path to write the generated profile. -# tracing_flame_output_path = "./tracing.folded" +# This item is undocumented. Please contribute documentation for it. +# +#tracing_flame_output_path = "./tracing.folded" -# Enable the tokio-console. This option is only relevant to developers. -See: docs/development.md#debugging-with-tokio-console for more information. -#tokio_console = false +# Examples: +# - No proxy (default): +# proxy = "none" +# +# - For global proxy, create the section at the bottom of this file: +# [global.proxy] +# global = { url = "socks5h://localhost:9050" } +# +# - To proxy some domains: +# [global.proxy] +# [[global.proxy.by_domain]] +# url = "socks5h://localhost:9050" +# include = ["*.onion", "matrix.myspecial.onion"] +# exclude = ["*.myspecial.onion"] +# +# Include vs. Exclude: +# - If include is an empty list, it is assumed to be `["*"]`. +# - If a domain matches both the exclude and include list, the proxy will +# only be used if it was included because of a more specific rule than +# it was excluded. In the above example, the proxy would be used for +# `ordinary.onion`, `matrix.myspecial.onion`, but not +# `hello.myspecial.onion`. +# +#proxy = "none" -# Enable backward-compatibility with Conduit's media directory by creating symlinks of media. This -option is only necessary if you plan on using Conduit again. Otherwise setting this to false -reduces filesystem clutter and overhead for managing these symlinks in the directory. This is now -disabled by default. You may still return to upstream Conduit but you have to run Conduwuit at -least once with this set to true and allow the media_startup_check to take place before shutting -down to return to Conduit. +# This item is undocumented. Please contribute documentation for it. # -# Disabled by default. -#media_compat_file_link = false +#jwt_secret = -# Prunes missing media from the database as part of the media startup checks. This means if you -delete files from the media directory the corresponding entries will be removed from the -database. This is disabled by default because if the media directory is accidentally moved or -inaccessible the metadata entries in the database will be lost with sadness. +# Servers listed here will be used to gather public keys of other servers +# (notary trusted key servers). # -# Disabled by default. -#prune_missing_media = false +# Currently, conduwuit doesn't support inbound batched key requests, so +# this list should only contain other Synapse servers +# +# example: ["matrix.org", "constellatory.net", "tchncs.de"] +# +#trusted_servers = ["matrix.org"] -# Checks consistency of the media directory at startup: -1. When `media_compat_file_link` is enbled, this check will upgrade media when switching back -and forth between Conduit and Conduwuit. Both options must be enabled to handle this. -2. When media is deleted from the directory, this check will also delete its database entry. +# Whether to query the servers listed in trusted_servers first or query +# the origin server first. For best security, querying the origin server +# first is advised to minimize the exposure to a compromised trusted +# server. For maximum federation/join performance this can be set to true, +# however other options exist to query trusted servers first under +# specific high-load circumstances and should be evaluated before setting +# this to true.
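#
# For illustration, a cautious key-fetching setup using only the options
# documented here (a sketch of the documented defaults shown side by
# side):
#
#   trusted_servers = ["matrix.org"]
#   query_trusted_key_servers_first = false
#   query_trusted_key_servers_first_on_join = true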
# -# If none of these checks apply to your use cases, and your media directory is significantly large -setting this to false may reduce startup time. +#query_trusted_key_servers_first = false + +# Whether to query the servers listed in trusted_servers first +# specifically on room joins. This option limits the exposure to a +# compromised trusted server to room joins only. The join operation +# requires gathering keys from many origin servers which can cause +# significant delays. Therefore this defaults to true to mitigate +# unexpected delays out-of-the-box. The security-paranoid or those +# willing to tolerate delays are advised to set this to false. Note that +# setting query_trusted_key_servers_first to true causes this option to +# be ignored. # -# Enabled by default. -#media_startup_check = true +#query_trusted_key_servers_first_on_join = true + +# Only query trusted servers for keys and never the origin server. This is +# intended for clusters or custom deployments using their trusted_servers +# as forwarding-agents to cache and deduplicate requests. Notary servers +# do not act as forwarding-agents by default, therefore do not enable this +# unless you know exactly what you are doing. +# +#only_query_trusted_key_servers = false + +# Maximum number of keys to request in each trusted server batch query. +# +#trusted_server_batch_size = 1024 + +# max log level for conduwuit. allows debug, info, warn, or error +# see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives +# +# **Caveat**: +# For release builds, the tracing crate is configured to only implement +# levels higher than error to avoid unnecessary overhead in the compiled +# binary from trace macros. For debug builds, this restriction is not +# applied. +# +#log = "info" + +# controls whether logs will be outputted with ANSI colours +# +#log_colors = true + +# configures the span events which will be outputted with the log +# +#log_span_events = "none" # OpenID token expiration/TTL in seconds # -# These are the OpenID tokens that are primarily used for Matrix account integrations, *not* OIDC/OpenID Connect/etc +# These are the OpenID tokens that are primarily used for Matrix account +# integrations (e.g. Vector Integrations in Element), *not* OIDC/OpenID +# Connect/etc # -# Defaults to 3600 (1 hour) #openid_token_ttl = 3600 -# Emergency password feature. This password set here will let you login to the server service account (e.g. `@conduit`) -and let you run admin commands, invite yourself to the admin room, etc. +# static TURN username to provide the client if not using a shared secret +# ("turn_secret"). It is recommended to use a shared secret over static +# credentials. # -# no default. -#emergency_password = "" - +#turn_username = false -### Generic database options +# static TURN password to provide the client if not using a shared secret +# ("turn_secret"). It is recommended to use a shared secret over static +# credentials. +# +#turn_password = false -# Set this to any float value to multiply conduwuit's in-memory LRU caches with. -By default, the caches scale automatically with cpu-core-count. -May be useful if you have significant memory to spare to increase performance. +# vector list of TURN URIs/servers to use # -# This was previously called `conduit_cache_capacity_modifier` +# replace "example.turn.uri" with your TURN domain, such as the coturn +# "realm" config option.
if using TURN over TLS, replace the URI prefix +# "turn:" with "turns:" # -# Defaults to 1.0. -#cache_capacity_modifier = 1.0 +# example: ["turn:example.turn.uri?transport=udp", +# "turn:example.turn.uri?transport=tcp"] +# +#turn_uris = [] -# Set this to any float value in megabytes for conduwuit to tell the database engine that this much memory is available for database-related caches. -May be useful if you have significant memory to spare to increase performance. -Defaults to 128.0 + (64.0 * CPU core count). -#db_cache_capacity_mb = 256.0 +# TURN secret to use for generating the HMAC-SHA1 hash as part of username +# and password generation +# +# this is more secure, but if needed you can use traditional +# static username/password credentials. +# +#turn_secret = false +# TURN secret to use that's read from the file path specified +# +# this takes priority over "turn_secret", and falls back to +# "turn_secret" if invalid or failed to open. +# +# example: "/etc/conduwuit/.turn_secret" +# +#turn_secret_file = -### RocksDB options +# TURN TTL in seconds +# +#turn_ttl = 86400 -# Set this to true to use RocksDB config options that are tailored to HDDs (slower device storage) +# List/vector of room IDs or room aliases that conduwuit will make newly +# registered users join. The rooms specified must be rooms that you +# have joined at least once on the server, and must be public. # -# It is worth noting that by default, conduwuit will use RocksDB with Direct IO enabled. *Generally* speaking this improves performance as it bypasses buffered I/O (system page cache). -However there is a potential chance that Direct IO may cause issues with database operations if your setup is uncommon. This has been observed with FUSE filesystems, and possibly ZFS filesystem. -RocksDB generally deals/corrects these issues but it cannot account for all setups. -If you experience any weird RocksDB issues, try enabling this option as it turns off Direct IO and feel free to report in the conduwuit Matrix room if this option fixes your DB issues. -See https://github.com/facebook/rocksdb/wiki/Direct-IO for more information. +# example: ["#conduwuit:puppygock.gay", +# "!eoIzvAvVwY23LPDay8:puppygock.gay"] # -# Defaults to false -#rocksdb_optimize_for_spinning_disks = false +#auto_join_rooms = [] -# Enables direct-io to increase database performance. This is enabled by default. Set this option to false if the -database resides on a filesystem which does not support direct-io. -#rocksdb_direct_io = true +# Config option to automatically deactivate the account of any user who +# attempts to join a: +# - banned room +# - forbidden room alias +# - room alias or ID with a forbidden server name +# +# This may be useful if all your banned lists consist of toxic rooms or +# servers that no good faith user would ever attempt to join, and +# to automatically remediate the problem without any admin user +# intervention. +# +# This will also make the user leave all rooms. Federation (e.g. remote +# room invites) are ignored here. +# +# Defaults to false as rooms can be banned for non-moderation-related +# reasons +# +#auto_deactivate_banned_room_attempts = false -# RocksDB log level. This is not the same as conduwuit's log level. This is the log level for the RocksDB engine/library -which show up in your database folder/path as `LOG` files. Defaults to error. conduwuit will typically log RocksDB errors as normal. +# RocksDB log level. This is not the same as conduwuit's log level.
This +# is the log level for the RocksDB engine/library, whose logs show up in +# your database folder/path as `LOG` files. conduwuit will log RocksDB +# errors as normal through tracing. +# #rocksdb_log_level = "error" -# Max RocksDB `LOG` file size before rotating in bytes. Defaults to 4MB. +# This item is undocumented. Please contribute documentation for it. +# +#rocksdb_log_stderr = false + +# Max RocksDB `LOG` file size before rotating, in bytes. Defaults to 4MB +# (4194304 bytes). +# #rocksdb_max_log_file_size = 4194304 -# Time in seconds before RocksDB will forcibly rotate logs. Defaults to 0. +# Time in seconds before RocksDB will forcibly rotate logs. +# #rocksdb_log_time_to_roll = 0 -# Amount of threads that RocksDB will use for parallelism on database operatons such as cleanup, sync, flush, compaction, etc. Set to 0 to use all your logical threads. +# Set this to true to use RocksDB config options that are tailored to HDDs +# (slower device storage) # -# Defaults to your CPU logical thread count. -#rocksdb_parallelism_threads = 0 +# It is worth noting that by default, conduwuit will use RocksDB with +# Direct IO enabled. *Generally* speaking this improves performance as it +# bypasses buffered I/O (system page cache). However there is a potential +# chance that Direct IO may cause issues with database operations if your +# setup is uncommon. This has been observed with FUSE filesystems, and +# possibly ZFS filesystem. RocksDB generally deals/corrects these issues +# but it cannot account for all setups. If you experience any weird +# RocksDB issues, try enabling this option as it turns off Direct IO and +# feel free to report in the conduwuit Matrix room if this option fixes +# your DB issues. +# +# See https://github.com/facebook/rocksdb/wiki/Direct-IO for more information. +# +#rocksdb_optimize_for_spinning_disks = false -# Enables idle IO priority for compaction thread. This prevents any unexpected lag in the server's operation and -is usually a good idea. Enabled by default. -#rocksdb_compaction_ioprio_idle = true +# Enables direct-io to increase database performance via unbuffered I/O. +# +# See https://github.com/facebook/rocksdb/wiki/Direct-IO for more details about Direct IO and RocksDB. +# +# Set this option to false if the database resides on a filesystem which +# does not support direct-io, like FUSE, or any form of complex filesystem +# setup such as possibly ZFS. +# +#rocksdb_direct_io = true -# Enables idle CPU priority for compaction thread. This is not enabled by default to prevent compaction from -falling too far behind on busy systems. -#rocksdb_compaction_prio_idle = false +# Amount of threads that RocksDB will use for parallelism on database +# operations such as cleanup, sync, flush, compaction, etc. Set to 0 to use +# all your logical threads. Defaults to your CPU logical thread count. +# +#rocksdb_parallelism_threads = 0 -# Maximum number of LOG files RocksDB will keep. This must *not* be set to 0. It must be at least 1. -Defaults to 3 as these are not very useful. +# Maximum number of LOG files RocksDB will keep. This must *not* be set to +# 0. It must be at least 1. Defaults to 3 as these are not very useful +# unless troubleshooting/debugging a RocksDB bug. +# #rocksdb_max_log_files = 3 # Type of RocksDB database compression to use. +# # Available options are "zstd", "zlib", "bz2", "lz4", or "none" -It is best to use ZSTD as an overall good balance between speed/performance, storage, IO amplification, and CPU usage.
-For more performance but less compression (more storage used) and less CPU usage, use LZ4. -See https://github.com/facebook/rocksdb/wiki/Compression for more details. +# +# It is best to use ZSTD as an overall good balance between +# speed/performance, storage, IO amplification, and CPU usage. +# For more performance but less compression (more storage used) and less +# CPU usage, use LZ4. See https://github.com/facebook/rocksdb/wiki/Compression for more details. # # "none" will disable compression. # -# Defaults to "zstd" #rocksdb_compression_algo = "zstd" -# Level of compression the specified compression algorithm for RocksDB to use. -Default is 32767, which is internally read by RocksDB as the default magic number and -translated to the library's default compression level as they all differ. +# Level of compression that the specified compression algorithm for +# RocksDB will use. +# +# Default is 32767, which is internally read by RocksDB as the +# default magic number and translated to the library's default +# compression level as they all differ. +# See their `kDefaultCompressionLevel`. +# +#rocksdb_compression_level = 32767 -# Level of compression the specified compression algorithm for the bottommost level/data for RocksDB to use. -Default is 32767, which is internally read by RocksDB as the default magic number and -translated to the library's default compression level as they all differ. +# Level of compression that the specified compression algorithm will use +# for the bottommost level/data in RocksDB. Default is 32767, which is +# internally read by RocksDB as the default magic number and translated +# to the library's default compression level as they all differ. +# See their `kDefaultCompressionLevel`. # -# Since this is the bottommost level (generally old and least used data), it may be desirable to have a very -high compression level here as it's lesss likely for this data to be used. Research your chosen compression algorithm. +# Since this is the bottommost level (generally old and least used data), +# it may be desirable to have a very high compression level here as it's +# less likely for this data to be used. Research your chosen compression +# algorithm. # +#rocksdb_bottommost_compression_level = 32767 -# Whether to enable RocksDB "bottommost_compression". -At the expense of more CPU usage, this will further compress the database to reduce more storage. -It is recommended to use ZSTD compression with this for best compression results. +# Whether to enable RocksDB's "bottommost_compression". +# +# At the expense of more CPU usage, this will further compress the +# database to reduce more storage. It is recommended to use ZSTD +# compression with this for best compression results. This may be useful +# if you're trying to reduce storage usage from the database. +# +# See https://github.com/facebook/rocksdb/wiki/Compression for more details. # -# Defaults to false as this uses more CPU when compressing. #rocksdb_bottommost_compression = false -# Level of statistics collection. Some admin commands to display database statistics may require -this option to be set. Database performance may be impacted by higher settings. +# Database recovery mode (for RocksDB WAL corruption) # -# Option is a number ranging from 0 to 6: -0 = No statistics. -1 = No statistics in release mode (default). -2 to 3 = Statistics with no performance impact. -3 to 5 = Statistics with possible performance impact. -6 = All statistics.
+# Use this option when the server reports corruption and refuses to start. +# Set mode 2 (PointInTime) to cleanly recover from this corruption. The +# server will continue from the last good state, several seconds or +# minutes prior to the crash. Clients may have to run "clear-cache & +# reload" to account for the rollback. Upon success, you may reset the +# mode back to default and restart again. Please note in some cases the +# corruption error may not be cleared for at least 30 minutes of +# operation in PointInTime mode. # -# Defaults to 1 (No statistics, except in debug-mode) -#rocksdb_stats_level = 1 +# As a very last ditch effort, if PointInTime does not fix or resolve +# anything, you can try mode 3 (SkipAnyCorruptedRecord) but this will +# leave the server in a potentially inconsistent state. +# +# The default mode 1 (TolerateCorruptedTailRecords) will automatically +# drop the last entry in the database if corrupted during shutdown, but +# nothing more. It is extraordinarily unlikely this will desynchronize +# clients. To disable any form of silent rollback set mode 0 +# (AbsoluteConsistency). +# +# The options are: +# 0 = AbsoluteConsistency +# 1 = TolerateCorruptedTailRecords (default) +# 2 = PointInTime (use me if trying to recover) +# 3 = SkipAnyCorruptedRecord (you now voided your Conduwuit warranty) +# +# See https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes for more information on these modes. +# +# See https://conduwuit.puppyirl.gay/troubleshooting.html#database-corruption for more details on recovering a corrupt database. +# +#rocksdb_recovery_mode = 1 # Database repair mode (for RocksDB SST corruption) # -# Use this option when the server reports corruption while running or panics. If the server refuses -# to start use the recovery mode options first. Corruption errors containing the acronym 'SST' which -# occur after startup will likely require this option. +# Use this option when the server reports corruption while running or +# panics. If the server refuses to start use the recovery mode options +# first. Corruption errors containing the acronym 'SST' which occur after +# startup will likely require this option. # -# - Backing up your database directory is recommended prior to running the repair. -# - Disabling repair mode and restarting the server is recommended after running the repair. +# - Backing up your database directory is recommended prior to running the +# repair. +# - Disabling repair mode and restarting the server is recommended after +# running the repair. +# +# See https://conduwuit.puppyirl.gay/troubleshooting.html#database-corruption for more details on recovering a corrupt database. # -# Defaults to false #rocksdb_repair = false -# Database recovery mode (for RocksDB WAL corruption) +# This item is undocumented. Please contribute documentation for it. # -# Use this option when the server reports corruption and refuses to start. Set mode 2 (PointInTime) -# to cleanly recover from this corruption. The server will continue from the last good state, -# several seconds or minutes prior to the crash. Clients may have to run "clear-cache & reload" to -# account for the rollback. Upon success, you may reset the mode back to default and restart again. -# Please note in some cases the corruption error may not be cleared for at least 30 minutes of -# operation in PointInTime mode. +#rocksdb_read_only = false + +# This item is undocumented. Please contribute documentation for it. 
#
-# As a very last ditch effort, if PointInTime does not fix or resolve anything, you can try mode
-# 3 (SkipAnyCorruptedRecord) but this will leave the server in a potentially inconsistent state.
+#rocksdb_secondary = false
+
+# Enables idle CPU priority for compaction thread. This is not enabled by
+# default to prevent compaction from falling too far behind on busy
+# systems.
#
-# The default mode 1 (TolerateCorruptedTailRecords) will automatically drop the last entry in the
-# database if corrupted during shutdown, but nothing more. It is extraordinarily unlikely this will
-# desynchronize clients. To disable any form of silent rollback set mode 0 (AbsoluteConsistency).
+#rocksdb_compaction_prio_idle = false
+
+# Enables idle IO priority for compaction thread. This prevents any
+# unexpected lag in the server's operation and is usually a good idea.
+# Enabled by default.
#
-# The options are:
-# 0 = AbsoluteConsistency
-# 1 = TolerateCorruptedTailRecords (default)
-# 2 = PointInTime (use me if trying to recover)
-# 3 = SkipAnyCorruptedRecord (you now voided your Conduwuit warranty)
+#rocksdb_compaction_ioprio_idle = true
+
+# Config option to disable RocksDB compaction. You should never ever have
+# to disable this. If for some reason you find yourself needing to disable
+# this as part of troubleshooting or a bug, please reach out to us in the
+# conduwuit Matrix room with information and details.
#
-# See https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes for more information
+# Disabling compaction will lead to a significantly bloated and
+# explosively large database, gradually poor performance, unnecessarily
+# excessive disk reads/writes, and slower shutdowns and startups.
#
-# Defaults to 1 (TolerateCorruptedTailRecords)
-#rocksdb_recovery_mode = 1
+#rocksdb_compaction = true

+# Level of statistics collection. Some admin commands to display database
+# statistics may require this option to be set. Database performance may
+# be impacted by higher settings.
+#
+# Option is a number ranging from 0 to 6:
+# 0 = No statistics.
+# 1 = No statistics in release mode (default).
+# 2 to 3 = Statistics with no performance impact.
+# 3 to 5 = Statistics with possible performance impact.
+# 6 = All statistics.
+#
+#rocksdb_stats_level = 1

-### Domain Name Resolution and Caching
+# This is a password that can be configured to let you log in to the
+# server bot account (currently `@conduit`) for emergency troubleshooting
+# purposes such as recovering/recreating your admin room, or inviting
+# yourself back.
+#
+# See https://conduwuit.puppyirl.gay/troubleshooting.html#lost-access-to-admin-room for other ways to get back into your admin room.
+#
+# Once this password is unset, all sessions will be logged out for
+# security purposes.
+#
+# example: "F670$2CP@Hw8mG7RY1$%!#Ic7YA"
+#
+#emergency_password =

-# Maximum entries stored in DNS memory-cache. The size of an entry may vary so please take care if
-# raising this value excessively. Only decrease this when using an external DNS cache. Please note
-# that systemd does *not* count as an external cache, even when configured to do so.
-#dns_cache_entries = 32768
+# This item is undocumented. Please contribute documentation for it.
+#
+#notification_push_path = "/_matrix/push/v1/notify"

-# Minimum time-to-live in seconds for entries in the DNS cache. The default may appear high to most
-# administrators; this is by design. Only decrease this if you are using an external DNS cache.
-#dns_min_ttl = 10800
+# Config option to control local (your server only) presence
+# updates/requests. Note that presence on conduwuit is
+# very fast, unlike Synapse's. If using outgoing presence, this MUST be
+# enabled.
+#
+#allow_local_presence = true

-# Minimum time-to-live in seconds for NXDOMAIN entries in the DNS cache. This value is critical for
-# the server to federate efficiently. NXDOMAIN's are assumed to not be returning to the federation
-# and aggressively cached rather than constantly rechecked.
+# Config option to control incoming federated presence updates/requests.
#
-# Defaults to 3 days as these are *very rarely* false negatives.
-#dns_min_ttl_nxdomain = 259200
+# This option receives presence updates from other
+# servers, but does not send any unless `allow_outgoing_presence` is true.
+# Note that presence on conduwuit is very fast, unlike Synapse's.
+#
+#allow_incoming_presence = true

-# The number of seconds to wait for a reply to a DNS query. Please note that recursive queries can
-# take up to several seconds for some domains, so this value should not be too low.
-#dns_timeout = 10
+# Config option to control outgoing presence updates/requests.
+#
+# This option sends presence updates to other servers, but does not
+# receive any unless `allow_incoming_presence` is true.
+# Note that presence on conduwuit is very fast, unlike Synapse's.
+# If using outgoing presence, you MUST enable `allow_local_presence` as
+# well.
+#
+#allow_outgoing_presence = true

-# Number of retries after a timeout.
-#dns_attempts = 10
+# Config option to control how many seconds of inactivity before presence
+# updates mark you as idle. Defaults to 5 minutes.
+#
+#presence_idle_timeout_s = 300

-# Fallback to TCP on DNS errors. Set this to false if unsupported by nameserver.
-#dns_tcp_fallback = true
+# Config option to control how many seconds of inactivity before presence
+# updates mark you as offline. Defaults to 30 minutes.
+#
+#presence_offline_timeout_s = 1800

-# Enable to query all nameservers until the domain is found. Referred to as "trust_negative_responses" in hickory_resolver.
-# This can avoid useless DNS queries if the first nameserver responds with NXDOMAIN or an empty NOERROR response.
+# Config option to enable the presence idle timer for remote users.
+# Disabling is offered as an optimization for servers participating in
+# many large rooms or when resources are limited. Disabling it may cause
+# incorrect presence states (i.e. stuck online) to be seen for some
+# remote users.
#
-# The default is to query one nameserver and stop (false).
-#query_all_nameservers = true
+#presence_timeout_remote_users = true

-# Enables using *only* TCP for querying your specified nameservers instead of UDP.
+# Config option to control whether we should receive remote incoming read
+# receipts.
#
-# You very likely do *not* want this. hickory-resolver already falls back to TCP on UDP errors.
-# Defaults to false
-#query_over_tcp_only = false
+#allow_incoming_read_receipts = true

-# DNS A/AAAA record lookup strategy
+# Config option to control whether we should send read receipts to remote
+# servers.
#
-# Takes a number of one of the following options:
-# 1 - Ipv4Only (Only query for A records, no AAAA/IPv6)
-# 2 - Ipv6Only (Only query for AAAA records, no A/IPv4)
-# 3 - Ipv4AndIpv6 (Query for A and AAAA records in parallel, uses whatever returns a successful response first)
-# 4 - Ipv6thenIpv4 (Query for AAAA record, if that fails then query the A record)
-# 5 - Ipv4thenIpv6 (Query for A record, if that fails then query the AAAA record)
+#allow_outgoing_read_receipts = true
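+
+# As an illustrative sketch (not the defaults), a homeserver that does
+# not want to share presence or read receipts with other servers could
+# set:
+#
+# allow_outgoing_presence = false
+# allow_outgoing_read_receipts = false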
+
+# Config option to control outgoing typing updates to federation.
#
-# If you don't have IPv6 networking, then for better performance it may be suitable to set this to Ipv4Only (1) as
-# you will never ever use the AAAA record contents even if the AAAA record is successful instead of the A record.
+#allow_outgoing_typing = true
+
+# Config option to control incoming typing updates from federation.
#
-# Defaults to 5 - Ipv4ThenIpv6 as this is the most compatible and IPv4 networking is currently the most prevalent.
-#ip_lookup_strategy = 5
+#allow_incoming_typing = true

+# Config option to control maximum time federation user can indicate
+# typing.
+#
+#typing_federation_timeout_s = 30

-### Request Timeouts, Connection Timeouts, and Connection Pooling
+# Config option to control minimum time local client can indicate typing.
+# This does not override a client's request to stop typing. It only
+# enforces a minimum value in case of no stop request.
+#
+#typing_client_timeout_min_s = 15

-## Request Timeouts are HTTP response timeouts
-## Connection Timeouts are TCP connection timeouts
-##
-## Connection Pooling Timeouts are timeouts for keeping an open idle connection alive.
-## Connection pooling and keepalive is very useful for federation or other places where for performance reasons,
-## we want to keep connections open that we will re-use frequently due to TCP and TLS 1.3 overhead/expensiveness.
-##
-## Generally these defaults are the best, but if you find a reason to need to change these they are here.
+# Config option to control maximum time local client can indicate typing.
+#
+#typing_client_timeout_max_s = 45

-# Default/base connection timeout.
-# This is used only by URL previews and update/news endpoint checks
+# Set this to true for conduwuit to compress HTTP response bodies using
+# zstd. This option does nothing if conduwuit was not built with the
+# `zstd_compression` feature. Please be aware that enabling HTTP
+# compression may weaken TLS. Most users should not need to enable this.
+# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH
+# before deciding to enable this.
#
-# Defaults to 10 seconds
-#request_conn_timeout = 10
+#zstd_compression = false

-# Default/base request timeout. The time waiting to receive more data from another server.
-# This is used only by URL previews, update/news, and misc endpoint checks
+# Set this to true for conduwuit to compress HTTP response bodies using
+# gzip. This option does nothing if conduwuit was not built with the
+# `gzip_compression` feature. Please be aware that enabling HTTP
+# compression may weaken TLS. Most users should not need to enable this.
+# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before
+# deciding to enable this.
#
-# Defaults to 35 seconds
-#request_timeout = 35
+# If you are in a large number of rooms, you may find that enabling this
+# is necessary to reduce the significantly large response bodies.
+#
+#gzip_compression = false

-# Default/base request total timeout. The time limit for a whole request. This is set very high to not
-# cancel healthy requests while serving as a backstop.
-# This is used only by URL previews and update/news endpoint checks
+# Set this to true for conduwuit to compress HTTP response bodies using
+# brotli. This option does nothing if conduwuit was not built with the
+# `brotli_compression` feature. Please be aware that enabling HTTP
+# compression may weaken TLS. Most users should not need to enable this.
+# See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH before
+# deciding to enable this.
#
-# Defaults to 320 seconds
-#request_total_timeout = 320
+#brotli_compression = false

-# Default/base idle connection pool timeout
-# This is used only by URL previews and update/news endpoint checks
+# Set to true to allow user type "guest" registrations. Some clients like
+# Element attempt to register guest users automatically.
#
-# Defaults to 5 seconds
-#request_idle_timeout = 5
+#allow_guest_registration = false

-# Default/base max idle connections per host
-# This is used only by URL previews and update/news endpoint checks
+# Set to true to log guest registrations in the admin room. Note that
+# these may be noisy or unnecessary if you're running a public homeserver.
#
-# Defaults to 1 as generally the same open connection can be re-used
-#request_idle_per_host = 1
+#log_guest_registrations = false

-# Federation well-known resolution connection timeout
+# Set to true to allow guest registrations/users to auto join any rooms
+# specified in `auto_join_rooms`.
#
-# Defaults to 6 seconds
-#well_known_conn_timeout = 6
+#allow_guests_auto_join_rooms = false

-# Federation HTTP well-known resolution request timeout
+# Config option to control whether the legacy unauthenticated Matrix media
+# repository endpoints will be enabled. These endpoints consist of:
+# - /_matrix/media/*/config
+# - /_matrix/media/*/upload
+# - /_matrix/media/*/preview_url
+# - /_matrix/media/*/download/*
+# - /_matrix/media/*/thumbnail/*
#
-# Defaults to 10 seconds
-#well_known_timeout = 10
+# The authenticated equivalent endpoints are always enabled.
+#
+# Defaults to true for now, but this is highly subject to change, likely
+# in the next release.
+#
+#allow_legacy_media = true

-# Federation client request timeout
-# You most definitely want this to be high to account for extremely large room joins, slow homeservers, your own resources etc.
+# This item is undocumented. Please contribute documentation for it.
#
-# Defaults to 300 seconds
-#federation_timeout = 300
+#freeze_legacy_media = true

-# Federation client idle connection pool timeout
+# Checks consistency of the media directory at startup:
+# 1. When `media_compat_file_link` is enabled, this check will upgrade
+# media when switching back and forth between Conduit and conduwuit.
+# Both options must be enabled to handle this.
+# 2. When media is deleted from the directory, this check will also delete
+# its database entry.
#
-# Defaults to 25 seconds
-#federation_idle_timeout = 25
+# If none of these checks apply to your use cases, and your media
+# directory is significantly large, setting this to false may reduce
+# startup time.
+#
+#media_startup_check = true

-# Federation client max idle connections per host
+# Enable backward-compatibility with Conduit's media directory by creating
+# symlinks of media. This option is only necessary if you plan on using
+# Conduit again. Otherwise setting this to false reduces filesystem
+# clutter and overhead for managing these symlinks in the directory. This
+# is now disabled by default. You may still return to upstream Conduit
+# but you have to run conduwuit at least once with this set to true and
+# allow the media_startup_check to take place before shutting
+# down to return to Conduit.
#
-# Defaults to 1 as generally the same open connection can be re-used
-#federation_idle_per_host = 1
+#media_compat_file_link = false
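+
+# As an illustrative sketch (not the default), a large deployment that
+# never plans on switching back to Conduit and wants faster startups
+# could skip the media consistency check:
+#
+# media_startup_check = false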

-# Federation sender request timeout
-# The time it takes for the remote server to process sent transactions can take a while.
+# Prunes missing media from the database as part of the media startup
+# checks. This means if you delete files from the media directory the
+# corresponding entries will be removed from the database. This is
+# disabled by default because if the media directory is accidentally moved
+# or inaccessible, the metadata entries in the database will be lost with
+# sadness.
#
-# Defaults to 180 seconds
-#sender_timeout = 180
+#prune_missing_media = false

-# Federation sender idle connection pool timeout
+# Vector list of servers that conduwuit will refuse to download remote
+# media from.
#
-# Defaults to 180 seconds
-#sender_idle_timeout = 180
+#prevent_media_downloads_from = []

-# Federation sender transaction retry backoff limit
+# List of forbidden server names that we will block incoming AND outgoing
+# federation with, and block client room joins / remote user invites.
#
-# Defaults to 86400 seconds
-#sender_retry_backoff_limit = 86400
+# This check is applied on the room ID, room alias, sender server name,
+# sender user's server name, inbound federation X-Matrix origin, and
+# outbound federation handler.
+#
+# Basically "global" ACLs.
+#
+#forbidden_remote_server_names = []

-# Appservice URL request connection timeout
+# List of forbidden server names that we will block all outgoing federated
+# room directory requests for. Useful for preventing our users from
+# wandering into bad servers or spaces.
#
-# Defaults to 35 seconds as generally appservices are hosted within the same network
-#appservice_timeout = 35
+#forbidden_remote_room_directory_server_names = []

-# Appservice URL idle connection pool timeout
+# Vector list of IPv4 and IPv6 CIDR ranges / subnets *in quotes* that you
+# do not want conduwuit to send outbound requests to. Defaults to
+# RFC1918, unroutable, loopback, multicast, and testnet addresses for
+# security.
#
-# Defaults to 300 seconds
-#appservice_idle_timeout = 300
+# Please be aware that this is *not* a guarantee. You should be using a
+# firewall with zones as doing this on the application layer may have
+# bypasses.
+#
+# Currently this does not account for proxies in use like Synapse does.
+#
+# To disable, set this to be an empty vector (`[]`).
+#
+#ip_range_denylist = ["127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12",
+# "192.168.0.0/16", "100.64.0.0/10", "192.0.0.0/24", "169.254.0.0/16",
+# "192.88.99.0/24", "198.18.0.0/15", "192.0.2.0/24", "198.51.100.0/24",
+# "203.0.113.0/24", "224.0.0.0/4", "::1/128", "fe80::/10", "fc00::/7",
+# "2001:db8::/32", "ff00::/8", "fec0::/10"]

-# Notification gateway pusher idle connection pool timeout
+# Vector list of domains allowed to send requests to for URL previews.
+# Defaults to none. Note: this is a *contains* match, not an explicit
+# match. Putting "google.com" will match "https://google.com" and
+# "http://mymaliciousdomainexamplegoogle.com". Setting this to "*" will
+# allow all URL previews. Please note that this opens up significant
+# attack surface to your server, and you are expected to be aware of the
+# risks of doing so.
#
-# Defaults to 15 seconds
-#pusher_idle_timeout = 15
+#url_preview_domain_contains_allowlist = []

+# Vector list of explicit domains allowed to send requests to for URL
+# previews. Defaults to none. Note: This is an *explicit* match, not a
+# contains match. Putting "google.com" will match "https://google.com",
+# "http://google.com", but not
+# "https://mymaliciousdomainexamplegoogle.com". Setting this to "*" will
+# allow all URL previews. Please note that this opens up significant
+# attack surface to your server, and you are expected to be aware of the
+# risks of doing so.
+#
+#url_preview_domain_explicit_allowlist = []

-### Presence / Typing Indicators / Read Receipts
+# Vector list of explicit domains not allowed to send requests to for URL
+# previews. Defaults to none. Note: This is an *explicit* match, not a
+# contains match. Putting "google.com" will match "https://google.com",
+# "http://google.com", but not
+# "https://mymaliciousdomainexamplegoogle.com". The denylist is checked
+# first before the allowlist. Setting this to "*" will not do anything.
+#
+#url_preview_domain_explicit_denylist = []

-# Config option to control local (your server only) presence updates/requests. Defaults to true.
-# Note that presence on conduwuit is very fast unlike Synapse's.
-# If using outgoing presence, this MUST be enabled.
+# Vector list of URLs allowed to send requests to for URL previews.
+# Defaults to none. Note that this is a *contains* match, not an
+# explicit match. Putting "google.com" will match
+# "https://google.com/",
+# "https://google.com/url?q=https://mymaliciousdomainexample.com", and
+# "https://mymaliciousdomainexample.com/hi/google.com". Setting this to
+# "*" will allow all URL previews. Please note that this opens up
+# significant attack surface to your server, and you are expected to be
+# aware of the risks of doing so.
#
-#allow_local_presence = true
+#url_preview_url_contains_allowlist = []

-# Config option to control incoming federated presence updates/requests. Defaults to true.
-# This option receives presence updates from other servers, but does not send any unless `allow_outgoing_presence` is true.
-# Note that presence on conduwuit is very fast unlike Synapse's.
+# Maximum number of bytes allowed in a URL preview body size when
+# spidering. Defaults to 384KB (384000 bytes).
#
-#allow_incoming_presence = true
+#url_preview_max_spider_size = 384000

-# Config option to control outgoing presence updates/requests. Defaults to true.
-# This option sends presence updates to other servers, but does not receive any unless `allow_incoming_presence` is true.
-# Note that presence on conduwuit is very fast unlike Synapse's.
-# If using outgoing presence, you MUST enable `allow_local_presence` as well.
+# Option to decide whether you would like to run the domain allowlist
+# checks (contains and explicit) on the root domain or not. Does not apply
+# to the URL contains allowlist. Defaults to false.
#
-#allow_outgoing_presence = true
+# Example use case: If this is
+# enabled and you have "wikipedia.org" allowed in the explicit and/or
+# contains domain allowlist, it will allow all subdomains under
+# "wikipedia.org" such as "en.m.wikipedia.org" as the root domain is
+# checked and matched. Useful if the domain contains allowlist is still
+# too broad for you but you still want to allow all the subdomains under a
+# root domain.
+#
+#url_preview_check_root_domain = false
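+
+# As an illustrative sketch (hypothetical values, not defaults), to only
+# allow URL previews for Wikipedia and all of its subdomains you could
+# set:
+#
+# url_preview_domain_explicit_allowlist = ["wikipedia.org"]
+# url_preview_check_root_domain = true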

-# Config option to enable the presence idle timer for remote users. Disabling is offered as an optimization for
-# servers participating in many large rooms or when resources are limited. Disabling it may cause incorrect
-# presence states (i.e. stuck online) to be seen for some remote users. Defaults to true.
-#presence_timeout_remote_users = true
+# List of forbidden room aliases and room IDs as strings of regex
+# patterns.
+#
+# Regex can be used or explicit contains matches can be done by
+# just specifying the words (see example).
+#
+# This is checked upon room alias creation, custom room ID creation if
+# used, and startup as warnings if any room aliases in your database have
+# a forbidden room alias/ID.
+#
+# example: ["19dollarfortnitecards", "b[4a]droom"]
+#
+#forbidden_alias_names = []

-# Config option to control how many seconds before presence updates that you are idle. Defaults to 5 minutes.
-#presence_idle_timeout_s = 300
+# List of forbidden username patterns/strings.
+#
+# Regex can be used or explicit contains matches can be done by just
+# specifying the words (see example).
+#
+# This is checked upon username availability check, registration, and
+# startup as warnings if any local users in your database have a forbidden
+# username.
+#
+# example: ["administrator", "b[a4]dusernam[3e]"]
+#
+#forbidden_usernames = []

-# Config option to control how many seconds before presence updates that you are offline. Defaults to 30 minutes.
-#presence_offline_timeout_s = 1800
+# Retry failed and incomplete messages to remote servers immediately upon
+# startup. This is called bursting. If this is disabled, said messages
+# may not be delivered until more messages are queued for that server. Do
+# not change this option unless server resources are extremely limited or
+# the scale of the server's deployment is huge. Do not disable this
+# unless you know what you are doing.
+#
+#startup_netburst = true

-# Config option to control whether we should receive remote incoming read receipts.
-# Defaults to true.
-#allow_incoming_read_receipts = true
+# Limit the startup netburst to the most recent (default: 50) messages
+# queued for each remote server; all older
+# messages are dropped and not reattempted. The `startup_netburst` option
+# must be enabled for this value to have any effect. Do not change this
+# value unless you know what you are doing. Set this value to -1 to
+# reattempt every message without trimming the queues; this may consume
+# significant disk. Set this value to 0 to drop all messages without any
+# attempt at redelivery.
+#
+#startup_netburst_keep = 50

-# Config option to control whether we should send read receipts to remote servers.
-# Defaults to true.
-#allow_outgoing_read_receipts = true
+# Controls whether non-admin local users are forbidden from sending room
+# invites (local and remote), and if non-admin users can receive remote
+# room invites. Admins are always allowed to send and receive all room
+# invites.
+#
+#block_non_admin_invites = false

-# Config option to control outgoing typing updates to federation. Defaults to true.
-#allow_outgoing_typing = true
+# Allows admins to enter commands in rooms other than "#admins" (admin
+# room) by prefixing your message with "\!admin" or "\\!admin" followed
+# by a normal conduwuit admin command. The reply will be publicly visible
+# to the room, originating from the sender.
+#
+# example: \\!admin debug ping puppygock.gay
+#
+#admin_escape_commands = true

-# Config option to control incoming typing updates from federation. Defaults to true.
-#allow_incoming_typing = true
+# Controls whether the conduwuit admin room console / CLI will immediately
+# activate on startup. This option can also be enabled with the `--console`
+# conduwuit argument.
+#
+#admin_console_automatic = false

-# Config option to control maximum time federation user can indicate typing.
-#typing_federation_timeout_s = 30
+# Controls what admin commands will be executed on startup. This is a
+# vector list of strings of admin commands to run.
+#
+# This option can also be configured with the `--execute` conduwuit
+# argument and can take standard shell commands and environment variables.
+#
+# An example could be: `./conduwuit --execute "server admin-notice
+# conduwuit has started up at $(date)"`
+#
+# example: admin_execute = ["debug ping puppygock.gay", "debug echo hi"]
+#
+#admin_execute = []

-# Config option to control minimum time local client can indicate typing. This does not override
-# a client's request to stop typing. It only enforces a minimum value in case of no stop request.
-#typing_client_timeout_min_s = 15
+# Controls whether conduwuit should error and fail to start if an admin
+# execute command (`--execute` / `admin_execute`) fails.
+#
+#admin_execute_errors_ignore = false

-# Config option to control maximum time local client can indicate typing.
-#typing_client_timeout_max_s = 45
+# Controls the max log level for admin command log captures (logs
+# generated from running admin commands). Defaults to "info" on release
+# builds, else "debug" on debug builds.
+#
+#admin_log_capture = "info"

+# The default room tag to apply on the admin room.
+#
+# On some clients like Element, the room tag "m.server_notice" is a
+# special pinned room at the very bottom of your room list. The conduwuit
+# admin room can be pinned here so you always have an easy-to-access
+# shortcut dedicated to your admin room.
+#
+#admin_room_tag = "m.server_notice"

-### TURN / VoIP
+# Sentry.io crash/panic reporting, performance monitoring/metrics, etc.
+# This is NOT enabled by default. conduwuit's default Sentry reporting
+# endpoint is o4506996327251968.ingest.us.sentry.io
+#
+#sentry = false

-# vector list of TURN URIs/servers to use
+# Sentry reporting URL if a custom one is desired.
#
-# replace "example.turn.uri" with your TURN domain, such as the coturn "realm".
-# if using TURN over TLS, replace "turn:" with "turns:"
+#sentry_endpoint = "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536"
+
+# Report your conduwuit server_name in Sentry.io crash reports and metrics.
#
-# No default
-#turn_uris = ["turn:example.turn.uri?transport=udp", "turn:example.turn.uri?transport=tcp"]
+#sentry_send_server_name = false

-# TURN secret to use that's read from the file path specified
+# Performance monitoring/tracing sample rate for Sentry.io.
#
-# this takes priority over "turn_secret" first, and falls back to "turn_secret" if invalid or
-# failed to open.
+# Note that too high values may impact performance, and can be disabled by
+# setting it to 0.0 (0%). This value is read as a percentage by Sentry,
+# represented as a decimal. Defaults to 15% of traces (0.15).
#
-# no default
-#turn_secret_file = "/path/to/secret.txt"
+#sentry_traces_sample_rate = 0.15

-# TURN secret to use for generating the HMAC-SHA1 hash apart of username and password generation
+# Whether to attach a stacktrace to Sentry reports.
#
-# this is more secure, but if needed you can use traditional username/password below.
+#sentry_attach_stacktrace = false
+
+# Send panics to Sentry. This is true by default, but the global "sentry"
+# config option must be enabled for any data to be sent.
#
-# no default
-#turn_secret = ""
+#sentry_send_panic = true

-# TURN username to provide the client
+# Send errors to Sentry. This is true by default, but Sentry has to be
+# enabled. This option is only effective in release-mode; forced to false
+# in debug-mode.
#
-# no default
-#turn_username = ""
+#sentry_send_error = true

-# TURN password to provide the client
+# Controls the tracing log level for Sentry to send things like
+# breadcrumbs and transactions.
#
-# no default
-#turn_password = ""
+#sentry_filter = "info"

-# TURN TTL
+# Enable the tokio-console. This option is only relevant to developers.
+# See https://conduwuit.puppyirl.gay/development.html#debugging-with-tokio-console for more information.
#
-# Default is 86400 seconds
-#turn_ttl = 86400
+#tokio_console = false

-# allow guests/unauthenticated users to access TURN credentials
+# This item is undocumented. Please contribute documentation for it.
#
-# this is the equivalent of Synapse's `turn_allow_guests` config option. this allows
-# any unauthenticated user to call `/_matrix/client/v3/voip/turnServer`.
+#test = false
+
+# Controls whether admin room notices like account registrations, password
+# changes, account deactivations, room directory publications, etc. will
+# be sent to the admin room. Update notices and normal admin command
+# responses will still be sent.
#
-# defaults to false
-#turn_allow_guests = false
+#admin_room_notices = true

+[global.tls]

-# Other options not in [global]:
+# Path to a valid TLS certificate file.
+#
+# example: "/path/to/my/certificate.crt"
#
+#certs =
+
+# Path to a valid TLS certificate private key.
#
-# Enables running conduwuit with direct TLS support
-# It is strongly recommended you use a reverse proxy instead. This is primarily relevant for test suites like complement that require a private CA setup.
-# [global.tls]
-# certs = "/path/to/my/certificate.crt"
-# key = "/path/to/my/private_key.key"
+# example: "/path/to/my/certificate.key"
#
+#key =
+
# Whether to listen and allow for HTTP and HTTPS connections (insecure!)
-# This config option is only available if conduwuit was built with `axum_dual_protocol` feature (not default feature)
-# Defaults to false
+#
#dual_protocol = false

+[global.well_known]
+
+# The server base domain of the URL with a specific port that the server
+# well-known file will serve. This should contain a port at the end, and
+# should not be a URL.
+#
+# example: "matrix.example.com:443"
+#
+#server =

-# If you are using delegation via well-known files and you cannot serve them from your reverse proxy, you can
-# uncomment these to serve them directly from conduwuit. This requires proxying all requests to conduwuit, not just `/_matrix` to work.
+# The server URL that the client well-known file will serve. This should
+# not contain a port, and should just be a valid HTTPS URL.
#
-#[global.well_known]
-#server = "matrix.example.com:443"
-#client = "https://matrix.example.com"
+# example: "https://matrix.example.com"
#
-# A single contact and/or support page for /.well-known/matrix/support
-# All options here are strings. Currently only supports 1 single contact.
-# No default.
+#client =
+
+# This item is undocumented. Please contribute documentation for it.
+#
+#support_page =
+
+# This item is undocumented. Please contribute documentation for it.
+# +#support_role = + +# This item is undocumented. Please contribute documentation for it. +# +#support_email = + +# This item is undocumented. Please contribute documentation for it. # -#support_page = "" -#support_role = "" -#support_email = "" -#support_mxid = "" +#support_mxid = diff --git a/deps/rust-rocksdb/Cargo.toml b/deps/rust-rocksdb/Cargo.toml index 8c168b24f80ecd696e6f4b5b9e9b1c279e494e07..908a2911c45a4fabeac52816ca0ec1710070a0d6 100644 --- a/deps/rust-rocksdb/Cargo.toml +++ b/deps/rust-rocksdb/Cargo.toml @@ -27,7 +27,7 @@ malloc-usable-size = ["rust-rocksdb/malloc-usable-size"] [dependencies.rust-rocksdb] git = "https://github.com/girlbossceo/rust-rocksdb-zaidoon1" -rev = "c1e5523eae095a893deaf9056128c7dbc2d5fd73" +rev = "2bc5495a9f8f75073390c326b47ee5928ab7c7f0" #branch = "master" default-features = false diff --git a/docs/deploying/docker-compose.for-traefik.yml b/docs/deploying/docker-compose.for-traefik.yml index 1c615673accdd0e2de00454c0add579693f70745..b43164269b1de0de637aacae19034893f0a243b9 100644 --- a/docs/deploying/docker-compose.for-traefik.yml +++ b/docs/deploying/docker-compose.for-traefik.yml @@ -14,9 +14,8 @@ services: environment: CONDUWUIT_SERVER_NAME: your.server.name.example # EDIT THIS CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit - CONDUWUIT_DATABASE_BACKEND: rocksdb CONDUWUIT_PORT: 6167 # should match the loadbalancer traefik label - CONDUWUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, ~20 MB + CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB CONDUWUIT_ALLOW_REGISTRATION: 'true' CONDUWUIT_ALLOW_FEDERATION: 'true' CONDUWUIT_ALLOW_CHECK_FOR_UPDATES: 'true' diff --git a/docs/deploying/docker-compose.with-caddy.yml b/docs/deploying/docker-compose.with-caddy.yml index 899f4d679afc0e7117daa0fb2945b1cd803889e3..c080293f05879ddfd7a901bcdf23ad07f737a1ad 100644 --- a/docs/deploying/docker-compose.with-caddy.yml +++ b/docs/deploying/docker-compose.with-caddy.yml @@ -30,9 +30,8 @@ services: environment: CONDUWUIT_SERVER_NAME: example.com # EDIT THIS CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit - CONDUWUIT_DATABASE_BACKEND: rocksdb CONDUWUIT_PORT: 6167 - CONDUWUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, ~20 MB + CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB CONDUWUIT_ALLOW_REGISTRATION: 'true' CONDUWUIT_ALLOW_FEDERATION: 'true' CONDUWUIT_ALLOW_CHECK_FOR_UPDATES: 'true' diff --git a/docs/deploying/docker-compose.with-traefik.yml b/docs/deploying/docker-compose.with-traefik.yml index f05006a55746253682103d3c78fdac438942ac18..89118c74267edde375cd389cd57765e14fd7c65a 100644 --- a/docs/deploying/docker-compose.with-traefik.yml +++ b/docs/deploying/docker-compose.with-traefik.yml @@ -15,7 +15,8 @@ services: CONDUWUIT_SERVER_NAME: your.server.name.example # EDIT THIS CONDUWUIT_TRUSTED_SERVERS: '["matrix.org"]' CONDUWUIT_ALLOW_REGISTRATION: 'false' # After setting a secure registration token, you can enable this - CONDUWUIT_REGISTRATION_TOKEN: # This is a token you can use to register on the server + CONDUWUIT_REGISTRATION_TOKEN: "" # This is a token you can use to register on the server + #CONDUWUIT_REGISTRATION_TOKEN_FILE: "" # Alternatively you can configure a path to a token file to read CONDUWUIT_ADDRESS: 0.0.0.0 CONDUWUIT_PORT: 6167 # you need to match this with the traefik load balancer label if you're want to change it CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit @@ -23,7 +24,6 @@ services: ### Uncomment and change values as desired, note that conduwuit has plenty of config options, so you should check out the example example config too # Available 
levels are: error, warn, info, debug, trace - more info at: https://docs.rs/env_logger/*/env_logger/#enabling-logging
      # CONDUWUIT_LOG: info  # default is: "warn,state_res=warn"
-      # CONDUWUIT_ALLOW_JAEGER: 'false'
      # CONDUWUIT_ALLOW_ENCRYPTION: 'true'
      # CONDUWUIT_ALLOW_FEDERATION: 'true'
      # CONDUWUIT_ALLOW_CHECK_FOR_UPDATES: 'true'
@@ -31,7 +31,7 @@ services:
      # CONDUWUIT_ALLOW_OUTGOING_PRESENCE: true
      # CONDUWUIT_ALLOW_LOCAL_PRESENCE: true
      # CONDUWUIT_WORKERS: 10
-      # CONDUWUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, ~20 MB
+      # CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB
      # CONDUWUIT_NEW_USER_DISPLAYNAME_SUFFIX = "🏳️‍⚧️"

      # We need some way to serve the client and server .well-known json. The simplest way is via the CONDUWUIT_WELL_KNOWN
diff --git a/docs/deploying/docker-compose.yml b/docs/deploying/docker-compose.yml
index bc9f24777a56c40442aa3a284d30c0f5ee2e7dbf..3b7d84ed1f9d587025fe6e9e0892b6e313adb350 100644
--- a/docs/deploying/docker-compose.yml
+++ b/docs/deploying/docker-compose.yml
@@ -14,9 +14,8 @@ services:
     environment:
       CONDUWUIT_SERVER_NAME: your.server.name # EDIT THIS
       CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit
-      CONDUWUIT_DATABASE_BACKEND: rocksdb
       CONDUWUIT_PORT: 6167
-      CONDUWUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, ~20 MB
+      CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB
       CONDUWUIT_ALLOW_REGISTRATION: 'true'
       CONDUWUIT_ALLOW_FEDERATION: 'true'
       CONDUWUIT_ALLOW_CHECK_FOR_UPDATES: 'true'
diff --git a/docs/deploying/docker.md b/docs/deploying/docker.md
index 7b8fd1a2cc47cb8c6587f42e9f88491ce5cd464d..e9c49c71663a997ef86a5403e49158a13d36774c 100644
--- a/docs/deploying/docker.md
+++ b/docs/deploying/docker.md
@@ -40,7 +40,6 @@ ### Run
 docker run -d -p 8448:6167 \
     -v db:/var/lib/conduwuit/ \
     -e CONDUWUIT_SERVER_NAME="your.server.name" \
-    -e CONDUWUIT_DATABASE_BACKEND="rocksdb" \
     -e CONDUWUIT_ALLOW_REGISTRATION=false \
     --name conduit $LINK
```
diff --git a/docs/deploying/freebsd.md b/docs/deploying/freebsd.md
index 4ac83515b8da92a81abb3ac0952843b48dd687b4..65b40204b7ab1ccb776e2825ff670b39f1f9c73c 100644
--- a/docs/deploying/freebsd.md
+++ b/docs/deploying/freebsd.md
@@ -1,11 +1,5 @@ # conduwuit for FreeBSD

-conduwuit at the moment does not provide FreeBSD builds. Building conduwuit on
-FreeBSD requires a specific environment variable to use the system prebuilt
-RocksDB library instead of rust-rocksdb / rust-librocksdb-sys which does *not*
-work and will cause a build error or coredump.
+conduwuit at the moment does not provide FreeBSD builds or have FreeBSD packaging; however, conduwuit does build and work on FreeBSD using the system-provided RocksDB.

-Use the following environment variable: `ROCKSDB_LIB_DIR=/usr/local/lib`
-
-Such example commandline with it can be: `ROCKSDB_LIB_DIR=/usr/local/lib cargo
-build --release`
+Contributions for getting conduwuit packaged are welcome.
diff --git a/docs/deploying/generic.md b/docs/deploying/generic.md
index 1e44ab541a84fb21cdd35f4cbb46426672a0e024..f0b85a25cab11a3e143887d4180c40e438146ebb 100644
--- a/docs/deploying/generic.md
+++ b/docs/deploying/generic.md
@@ -42,6 +42,9 @@ ## Migrating from Conduit
 this will **NOT** work on conduwuit and you must configure delegation manually.
 This is not a mistake and no support for this feature will be added.

+If you are using SQLite, you **MUST** migrate to RocksDB. You can use this
+tool to migrate from SQLite to RocksDB: <https://github.com/ShadowJonathan/conduit_toolbox/>
+
 See the `[global.well_known]` config section, or configure your web server
 appropriately to send the delegation responses.
@@ -65,13 +68,25 @@ ## Adding a conduwuit user

 ## Forwarding ports in the firewall or the router

-conduwuit uses the ports 443 and 8448 both of which need to be open in the
-firewall.
+Matrix's default federation port is port 8448, and clients must use port 443.
+If you would like to use only port 443, or a different port, you will need to set up
+delegation. conduwuit has config options for doing delegation, or you can configure
+your reverse proxy to manually serve the necessary JSON files to do delegation
+(see the `[global.well_known]` config section).

 If conduwuit runs behind a router or in a container and has a different public
 IP address than the host system these public ports need to be forwarded directly
 or indirectly to the port mentioned in the config.

+Note for NAT users: if you have trouble connecting to your server from the inside
+of your network, you need to research your router and see if it supports "NAT
+hairpinning" or "NAT loopback".
+
+If your router does not support this feature, you need to research doing local
+DNS overrides and force your Matrix DNS records to use your local IP internally.
+This can be done at the host level using `/etc/hosts`. If you need this to be
+on the network level, consider something like NextDNS or Pi-Hole.
+
 ## Setting up a systemd service

 The systemd unit for conduwuit can be found
@@ -119,12 +134,16 @@ ## Setting up the Reverse Proxy
 (handles TLS, reverse proxy headers, etc transparently with proper defaults).

 Lighttpd is not supported as it seems to mess with the `X-Matrix` Authorization
-header, making federation non-functional. If using Apache, you need to use
-`nocanon` in your `ProxyPass` directive to prevent this (note that Apache
-isn't very good as a general reverse proxy).
+header, making federation non-functional. If a workaround is found, feel free to share it to get it added to the documentation here.
+
+If using Apache, you need to use `nocanon` in your `ProxyPass` directive to prevent this (note that Apache isn't very good as a general reverse proxy and we discourage using it if you can avoid it).
+
+If using Nginx, you need to give conduwuit the request URI using `$request_uri`, or leave the `proxy_pass` URL without a path, like so:
+- `proxy_pass http://127.0.0.1:6167$request_uri;`
+- `proxy_pass http://127.0.0.1:6167;`

 Nginx users may need to set `proxy_buffering off;` if there are issues with
-uploading media like images.
+uploading media like images. This is due to Nginx storing the entire POST content in memory (`/tmp`) and running out of memory on low-memory hardware.
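+
+If you let conduwuit serve the delegation files itself, a minimal sketch of
+the `[global.well_known]` section could look like this (the hostnames are
+illustrative placeholders, not defaults):
+
+```toml
+[global.well_known]
+server = "matrix.example.com:443"
+client = "https://matrix.example.com"
+```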

 You will need to reverse proxy everything under following routes:
 - `/_matrix/` - core Matrix C-S and S-S APIs
@@ -133,11 +152,20 @@ ## Setting up the Reverse Proxy

 You can optionally reverse proxy the following individual routes:
 - `/.well-known/matrix/client` and `/.well-known/matrix/server` if using
-conduwuit to perform delegation
+conduwuit to perform delegation (see the `[global.well_known]` config section)
 - `/.well-known/matrix/support` if using conduwuit to send the homeserver admin
 contact and support page (formerly known as MSC1929)
 - `/` if you would like to see `hewwo from conduwuit woof!` at the root

+See the following spec pages for more details on these files:
+- [`/.well-known/matrix/server`](https://spec.matrix.org/latest/client-server-api/#getwell-knownmatrixserver)
+- [`/.well-known/matrix/client`](https://spec.matrix.org/latest/client-server-api/#getwell-knownmatrixclient)
+- [`/.well-known/matrix/support`](https://spec.matrix.org/latest/client-server-api/#getwell-knownmatrixsupport)
+
+Examples of delegation:
+- <https://puppygock.gay/.well-known/matrix/server>
+- <https://puppygock.gay/.well-known/matrix/client>
+
 ### Caddy

 Create `/etc/caddy/conf.d/conduwuit_caddyfile` and enter this (substitute for
diff --git a/docs/deploying/kubernetes.md b/docs/deploying/kubernetes.md
new file mode 100644
index 0000000000000000000000000000000000000000..2a1bcb51ab2818f81d21361f1ee7bf5218b26d6e
--- /dev/null
+++ b/docs/deploying/kubernetes.md
@@ -0,0 +1,4 @@
+# conduwuit for Kubernetes
+
+conduwuit doesn't support horizontal scalability or distributed loading natively; however, a community-maintained Helm Chart is available here to run conduwuit on Kubernetes:
+<https://gitlab.cronce.io/charts/conduwuit>
diff --git a/docs/deploying/nixos.md b/docs/deploying/nixos.md
index 9147db7f3f74068ef9818d87bdaea277387a65bb..61fb391631f037424b3a3575f659c0c5cd3423eb 100644
--- a/docs/deploying/nixos.md
+++ b/docs/deploying/nixos.md
@@ -39,6 +39,15 @@ ### NixOS module
 welcome!), so [`services.matrix-conduit`][module] from Nixpkgs can be used to configure
 conduwuit.

+### Conduit NixOS Config Module and SQLite
+
+Beware! The [`services.matrix-conduit`][module] module defaults to SQLite as a database backend.
+Conduwuit dropped SQLite support in favor of exclusively supporting the much faster RocksDB.
+Make sure that you are using the RocksDB backend before migrating!
+
+There is a [tool to migrate a Conduit SQLite database to
+RocksDB](https://github.com/ShadowJonathan/conduit_toolbox/).
+
 If you want to run the latest code, you should get conduwuit from the `flake.nix`
 or `default.nix` and set [`services.matrix-conduit.package`][package]
 appropriately to use conduwuit instead of Conduit.
diff --git a/docs/differences.md b/docs/differences.md
index 6815d248574863d2f18676d5dd167a45cf372b1a..18ea7a1ffca798dbf571ac7505ddf8d10c293b2e 100644
--- a/docs/differences.md
+++ b/docs/differences.md
@@ -241,8 +241,7 @@ ## Maintenance/Stability
 - Fixed every single clippy (default lints) and rustc warnings, including some
 that were performance related or potential safety issues / unsoundness
 - Add a **lot** of other clippy and rustc lints and a rustfmt.toml file
-- Repo uses [Renovate](https://docs.renovatebot.com/),
-[Trivy](https://github.com/aquasecurity/trivy-action), and keeps ALL
+- Repo uses [Renovate](https://docs.renovatebot.com/) and keeps ALL
 dependencies as up to date as possible
 - Purge unmaintained/irrelevant/broken database backends (heed, sled, persy)
 and other unnecessary code or overhead
diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md
index c1499f3a1312d9ef053a153aa9d1310d65991321..74e19de762bcb2dafb6405b47795610ba5d67245 100644
--- a/docs/troubleshooting.md
+++ b/docs/troubleshooting.md
@@ -47,10 +47,11 @@ #### Direct IO

 Some filesystems may not like RocksDB using
 [Direct IO](https://github.com/facebook/rocksdb/wiki/Direct-IO). Direct IO is for
-non-buffered I/O which improves conduwuit performance, but at least FUSE is a
-filesystem potentially known to not like this. See the [example
-config](configuration/examples.md) for disabling it if needed. Issues from
-Direct IO on unsupported filesystems are usually shown as startup errors.
+non-buffered I/O which improves conduwuit performance and reduces system CPU
+usage, but at least FUSE, and possibly ZFS, are filesystems known to potentially
+not like this. See the [example config](configuration/examples.md) for
+disabling it if needed. Issues from Direct IO on unsupported filesystems are
+usually shown as startup errors.
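+
+A minimal sketch of the relevant change in your conduwuit config (see the
+example config for the full documentation of this option):
+
+```toml
+[global]
+rocksdb_direct_io = false
+```
+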
#### Database corruption diff --git a/flake.lock b/flake.lock index 271a215415dc668a52aabbcde3efb40bf694271c..7740e9254788c5e6502143c1d181ddc396c52f6e 100644 --- a/flake.lock +++ b/flake.lock @@ -922,16 +922,16 @@ "rocksdb": { "flake": false, "locked": { - "lastModified": 1729712930, - "narHash": "sha256-jlp4kPkRTpoJaUdobEoHd8rCGAQNBy4ZHZ6y5zL/ibw=", + "lastModified": 1731690620, + "narHash": "sha256-Xd4TJYqPERMJLXaGa6r6Ny1Wlw8Uy5Cyf/8q7nS58QM=", "owner": "girlbossceo", "repo": "rocksdb", - "rev": "871eda6953c3f399aae39808dcfccdd014885beb", + "rev": "292446aa2bc41699204d817a1e4b091679a886eb", "type": "github" }, "original": { "owner": "girlbossceo", - "ref": "v9.7.3", + "ref": "v9.7.4", "repo": "rocksdb", "type": "github" } diff --git a/flake.nix b/flake.nix index 85b7baa0e2dc940dddf90cf61854cf8f108cef40..113757a73d47ffe0694234570b33b07935008095 100644 --- a/flake.nix +++ b/flake.nix @@ -9,7 +9,7 @@ flake-utils.url = "github:numtide/flake-utils?ref=main"; nix-filter.url = "github:numtide/nix-filter?ref=main"; nixpkgs.url = "github:NixOS/nixpkgs?ref=nixpkgs-unstable"; - rocksdb = { url = "github:girlbossceo/rocksdb?ref=v9.7.3"; flake = false; }; + rocksdb = { url = "github:girlbossceo/rocksdb?ref=v9.7.4"; flake = false; }; liburing = { url = "github:axboe/liburing?ref=master"; flake = false; }; }; diff --git a/nix/pkgs/complement/default.nix b/nix/pkgs/complement/default.nix index 80e9ce2733ef13c4ce1b6e575654ff9a7c0df5d3..36f124001fb3fd2390bb425d949c40bfd6e30f03 100644 --- a/nix/pkgs/complement/default.nix +++ b/nix/pkgs/complement/default.nix @@ -25,6 +25,7 @@ let "tokio_console" # sentry telemetry isn't useful for complement, disabled by default anyways "sentry_telemetry" + "perf_measurements" # the containers don't use or need systemd signal support "systemd" # this is non-functional on nix for some reason @@ -96,6 +97,7 @@ dockerTools.buildImage { Env = [ "SSL_CERT_FILE=/complement/ca/ca.crt" "CONDUWUIT_CONFIG=${./config.toml}" + "RUST_BACKTRACE=full" ]; ExposedPorts = { diff --git a/nix/pkgs/main/default.nix b/nix/pkgs/main/default.nix index 1088b03cd414d19532b7dd111a6325107219ef90..d11c36cc5d88cb9d8c7ef36ae2f090ff2e069119 100644 --- a/nix/pkgs/main/default.nix +++ b/nix/pkgs/main/default.nix @@ -176,7 +176,7 @@ commonAttrs = { # # <https://github.com/input-output-hk/haskell.nix/issues/829> postInstall = with pkgsBuildHost; '' - find "$out" -type f -exec remove-references-to -t ${stdenv.cc} -t ${gcc} -t ${rustc.unwrapped} -t ${rustc} -t ${libidn2} -t ${libunistring} '{}' + + find "$out" -type f -exec remove-references-to -t ${stdenv.cc} -t ${gcc} -t ${libgcc} -t ${llvm} -t ${libllvm} -t ${rustc.unwrapped} -t ${rustc} -t ${libidn2} -t ${libunistring} '{}' + ''; }; in diff --git a/nix/pkgs/oci-image/default.nix b/nix/pkgs/oci-image/default.nix index 5078523bcc01382a4411348309dabad68d4749ce..9b6413106e89ff2b0cc23b5668e8c3eac2fb7a1a 100644 --- a/nix/pkgs/oci-image/default.nix +++ b/nix/pkgs/oci-image/default.nix @@ -24,5 +24,8 @@ dockerTools.buildLayeredImage { Cmd = [ "${lib.getExe main}" ]; + Env = [ + "RUST_BACKTRACE=full" + ]; }; } diff --git a/rustfmt.toml b/rustfmt.toml index 114677d4998533e6c617d2f8990fd61823a7ef09..fd912a1932d1fe6444667b7ad6faeec8cfc54214 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,28 +1,27 @@ -edition = "2021" - +array_width = 80 +chain_width = 60 +comment_width = 80 condense_wildcard_suffixes = true +edition = "2021" +fn_call_width = 80 +fn_params_layout = "Compressed" +fn_single_line = true format_code_in_doc_comments = true format_macro_bodies = true 
format_macro_matchers = true format_strings = true -hex_literal_case = "Upper" -max_width = 120 -tab_spaces = 4 -array_width = 80 -comment_width = 80 -wrap_comments = true -fn_params_layout = "Compressed" -fn_call_width = 80 -fn_single_line = true +group_imports = "StdExternalCrate" hard_tabs = true -match_block_trailing_comma = true +hex_literal_case = "Upper" imports_granularity = "Crate" +match_block_trailing_comma = true +max_width = 120 +newline_style = "Unix" normalize_comments = false reorder_impl_items = true reorder_imports = true -group_imports = "StdExternalCrate" -newline_style = "Unix" +tab_spaces = 4 use_field_init_shorthand = true use_small_heuristics = "Off" use_try_shorthand = true -chain_width = 60 +wrap_comments = true diff --git a/src/admin/Cargo.toml b/src/admin/Cargo.toml index d756b3cbdd171a1637f638997debd0ad8bcb6f77..f5cab4496285c66f51003f76852a6dc4d85882d9 100644 --- a/src/admin/Cargo.toml +++ b/src/admin/Cargo.toml @@ -29,10 +29,11 @@ release_max_log_level = [ clap.workspace = true conduit-api.workspace = true conduit-core.workspace = true +conduit-database.workspace = true conduit-macros.workspace = true conduit-service.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true log.workspace = true ruma.workspace = true serde_json.workspace = true diff --git a/src/admin/check/commands.rs b/src/admin/check/commands.rs index 0a983046447b417310a41f688fed462974b99bbd..88fca462fb2046e26741a01f9a9095659a3e2547 100644 --- a/src/admin/check/commands.rs +++ b/src/admin/check/commands.rs @@ -1,5 +1,6 @@ use conduit::Result; use conduit_macros::implement; +use futures::StreamExt; use ruma::events::room::message::RoomMessageEventContent; use crate::Command; @@ -10,14 +11,12 @@ #[implement(Command, params = "<'_>")] pub(super) async fn check_all_users(&self) -> Result<RoomMessageEventContent> { let timer = tokio::time::Instant::now(); - let results = self.services.users.db.iter(); + let users = self.services.users.iter().collect::<Vec<_>>().await; let query_time = timer.elapsed(); - let users = results.collect::<Vec<_>>(); - let total = users.len(); - let err_count = users.iter().filter(|user| user.is_err()).count(); - let ok_count = users.iter().filter(|user| user.is_ok()).count(); + let err_count = users.iter().filter(|_user| false).count(); + let ok_count = users.iter().filter(|_user| true).count(); let message = format!( "Database query completed in {query_time:?}:\n\n```\nTotal entries: {total:?}\nFailure/Invalid user count: \ diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 2d967006494f77762a704d353429e58e3d75e9da..f9d4a521fce2d34562f8d3e3bbbfc85ebb753d8e 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -1,18 +1,18 @@ use std::{ - collections::{BTreeMap, HashMap}, + collections::HashMap, fmt::Write, sync::Arc, time::{Instant, SystemTime}, }; -use api::client::validate_and_add_event_id; -use conduit::{debug, debug_error, err, info, trace, utils, warn, Error, PduEvent, Result}; +use conduit::{debug_error, err, info, trace, utils, utils::string::EMPTY, warn, Error, PduEvent, Result}; +use futures::{FutureExt, StreamExt}; use ruma::{ api::{client::error::ErrorKind, federation::event::get_room_state}, events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, }; -use tokio::sync::RwLock; +use service::rooms::state_compressor::HashSetCompressStateEvent; use tracing_subscriber::EnvFilter; use 
crate::admin_command; @@ -26,37 +26,39 @@ pub(super) async fn echo(&self, message: Vec<String>) -> Result<RoomMessageEvent #[admin_command] pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> { - let event_id = Arc::<EventId>::from(event_id); - if let Some(event) = self.services.rooms.timeline.get_pdu_json(&event_id)? { - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - - let room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - - let start = Instant::now(); - let count = self - .services - .rooms - .auth_chain - .event_ids_iter(room_id, vec![event_id]) - .await? - .count(); + let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await else { + return Ok(RoomMessageEventContent::notice_plain("Event not found.")); + }; - let elapsed = start.elapsed(); - Ok(RoomMessageEventContent::text_plain(format!( - "Loaded auth chain with length {count} in {elapsed:?}" - ))) - } else { - Ok(RoomMessageEventContent::text_plain("Event not found.")) - } + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + + let room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + + let start = Instant::now(); + let count = self + .services + .rooms + .auth_chain + .event_ids_iter(room_id, &[&event_id]) + .await? + .count() + .await; + + let elapsed = start.elapsed(); + Ok(RoomMessageEventContent::text_plain(format!( + "Loaded auth chain with length {count} in {elapsed:?}" + ))) } #[admin_command] pub(super) async fn parse_pdu(&self) -> Result<RoomMessageEventContent> { - if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + if self.body.len() < 2 + || !self.body[0].trim().starts_with("```") + || self.body.last().unwrap_or(&EMPTY).trim() != "```" { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. 
Add --help for details.", @@ -91,25 +93,28 @@ pub(super) async fn get_pdu(&self, event_id: Box<EventId>) -> Result<RoomMessage .services .rooms .timeline - .get_non_outlier_pdu_json(&event_id)?; - if pdu_json.is_none() { + .get_non_outlier_pdu_json(&event_id) + .await; + + if pdu_json.is_err() { outlier = true; - pdu_json = self.services.rooms.timeline.get_pdu_json(&event_id)?; + pdu_json = self.services.rooms.timeline.get_pdu_json(&event_id).await; } + match pdu_json { - Some(json) => { + Ok(json) => { let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json"); Ok(RoomMessageEventContent::notice_markdown(format!( "{}\n```json\n{}\n```", if outlier { - "Outlier PDU found in our database" + "Outlier (Rejected / Soft Failed) PDU found in our database" } else { "PDU found in our database" }, json_text ))) }, - None => Ok(RoomMessageEventContent::text_plain("PDU not found locally.")), + Err(_) => Ok(RoomMessageEventContent::text_plain("PDU not found locally.")), } } @@ -130,7 +135,9 @@ pub(super) async fn get_remote_pdu_list( )); } - if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + if self.body.len() < 2 + || !self.body[0].trim().starts_with("```") + || self.body.last().unwrap_or(&EMPTY).trim() != "```" { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. Add --help for details.", @@ -157,7 +164,8 @@ pub(super) async fn get_remote_pdu_list( .send_message(RoomMessageEventContent::text_plain(format!( "Failed to get remote PDU, ignoring error: {e}" ))) - .await; + .await + .ok(); warn!("Failed to get remote PDU, ignoring error: {e}"); } else { success_count = success_count.saturating_add(1); @@ -196,6 +204,7 @@ pub(super) async fn get_remote_pdu( &server, ruma::api::federation::event::get_event::v1::Request { event_id: event_id.clone().into(), + include_unredacted_content: None, }, ) .await @@ -210,12 +219,14 @@ pub(super) async fn get_remote_pdu( })?; trace!("Attempting to parse PDU: {:?}", &response.pdu); - let parsed_pdu = { + let _parsed_pdu = { let parsed_result = self .services .rooms .event_handler - .parse_incoming_pdu(&response.pdu); + .parse_incoming_pdu(&response.pdu) + .await; + let (event_id, value, room_id) = match parsed_result { Ok(t) => t, Err(e) => { @@ -230,22 +241,12 @@ pub(super) async fn get_remote_pdu( vec![(event_id, value, room_id)] }; - let pub_key_map = RwLock::new(BTreeMap::new()); - - debug!("Attempting to fetch homeserver signing keys for {server}"); - self.services - .server_keys - .fetch_required_signing_keys(parsed_pdu.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map) - .await - .unwrap_or_else(|e| { - warn!("Could not fetch all signatures for PDUs from {server}: {e:?}"); - }); - info!("Attempting to handle event ID {event_id} as backfilled PDU"); self.services .rooms .timeline - .backfill_pdu(&server, response.pdu, &pub_key_map) + .backfill_pdu(&server, response.pdu) + .boxed() .await?; let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json"); @@ -333,9 +334,12 @@ pub(super) async fn ping(&self, server: Box<ServerName>) -> Result<RoomMessageEv #[admin_command] pub(super) async fn force_device_list_updates(&self) -> Result<RoomMessageEventContent> { // Force E2EE device list updates for all users - for user_id in self.services.users.iter().filter_map(Result::ok) { - self.services.users.mark_device_key_update(&user_id)?; - } + self.services + .users + .stream() + .for_each(|user_id| 
self.services.users.mark_device_key_update(user_id)) + .await; + Ok(RoomMessageEventContent::text_plain( "Marked all devices for all users as having new keys to update", )) @@ -419,12 +423,10 @@ pub(super) async fn sign_json(&self) -> Result<RoomMessageEventContent> { let string = self.body[1..self.body.len().checked_sub(1).unwrap()].join("\n"); match serde_json::from_str(&string) { Ok(mut value) => { - ruma::signatures::sign_json( - self.services.globals.server_name().as_str(), - self.services.globals.keypair(), - &mut value, - ) - .expect("our request json is what ruma expects"); + self.services + .server_keys + .sign_json(&mut value) + .expect("our request json is what ruma expects"); let json_text = serde_json::to_string_pretty(&value).expect("canonical json is valid json"); Ok(RoomMessageEventContent::text_plain(json_text)) }, @@ -442,27 +444,31 @@ pub(super) async fn verify_json(&self) -> Result<RoomMessageEventContent> { } let string = self.body[1..self.body.len().checked_sub(1).unwrap()].join("\n"); - match serde_json::from_str(&string) { - Ok(value) => { - let pub_key_map = RwLock::new(BTreeMap::new()); - - self.services - .server_keys - .fetch_required_signing_keys([&value], &pub_key_map) - .await?; - - let pub_key_map = pub_key_map.read().await; - match ruma::signatures::verify_json(&pub_key_map, &value) { - Ok(()) => Ok(RoomMessageEventContent::text_plain("Signature correct")), - Err(e) => Ok(RoomMessageEventContent::text_plain(format!( - "Signature verification failed: {e}" - ))), - } + match serde_json::from_str::<CanonicalJsonObject>(&string) { + Ok(value) => match self.services.server_keys.verify_json(&value, None).await { + Ok(()) => Ok(RoomMessageEventContent::text_plain("Signature correct")), + Err(e) => Ok(RoomMessageEventContent::text_plain(format!( + "Signature verification failed: {e}" + ))), }, Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Invalid json: {e}"))), } } +#[admin_command] +pub(super) async fn verify_pdu(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> { + let mut event = self.services.rooms.timeline.get_pdu_json(&event_id).await?; + + event.remove("event_id"); + let msg = match self.services.server_keys.verify_event(&event, None).await { + Ok(ruma::signatures::Verified::Signatures) => "signatures OK, but content hash failed (redaction).", + Ok(ruma::signatures::Verified::All) => "signatures and hashes OK.", + Err(e) => return Err(e), + }; + + Ok(RoomMessageEventContent::notice_plain(msg)) +} + #[admin_command] #[tracing::instrument(skip(self))] pub(super) async fn first_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { @@ -470,7 +476,8 @@ pub(super) async fn first_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Roo .services .rooms .state_cache - .server_in_room(&self.services.globals.config.server_name, &room_id)? + .server_in_room(&self.services.globals.config.server_name, &room_id) + .await { return Ok(RoomMessageEventContent::text_plain( "We are not participating in the room / we don't know about the room ID.", @@ -481,8 +488,9 @@ pub(super) async fn first_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Roo .services .rooms .timeline - .first_pdu_in_room(&room_id)? 
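The recurring call-site migration in these hunks: getters that used to return `Result<Option<T>>` synchronously now return an awaitable `Result<T>`, so `ok_or_else` on the inner `Option` becomes `map_err` on the awaited `Result`. A minimal sketch of the before/after shapes, using illustrative function names (only `Error::bad_database` is taken from the diff itself):

```rust
use conduit::{Error, Result};

// Old shape: synchronous getter returning Result<Option<T>>.
fn first_pdu_old() -> Result<Option<String>> {
	Ok(None)
}

// New shape: async getter returning Result<T>; "not found" is now an Err.
async fn first_pdu_new() -> Result<String> {
	Err(Error::bad_database("Failed to find the first PDU in database"))
}

async fn call_sites() -> Result<String> {
	// Before: `?` unwraps the outer Result, ok_or_else maps the None case.
	let _before = first_pdu_old()?
		.ok_or_else(|| Error::bad_database("Failed to find the first PDU in database"))?;

	// After: await the future and map the error instead.
	first_pdu_new()
		.await
		.map_err(|_| Error::bad_database("Failed to find the first PDU in database"))
}
```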
- .ok_or_else(|| Error::bad_database("Failed to find the first PDU in database"))?; + .first_pdu_in_room(&room_id) + .await + .map_err(|_| Error::bad_database("Failed to find the first PDU in database"))?; Ok(RoomMessageEventContent::text_plain(format!("{first_pdu:?}"))) } @@ -494,7 +502,8 @@ pub(super) async fn latest_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Ro .services .rooms .state_cache - .server_in_room(&self.services.globals.config.server_name, &room_id)? + .server_in_room(&self.services.globals.config.server_name, &room_id) + .await { return Ok(RoomMessageEventContent::text_plain( "We are not participating in the room / we don't know about the room ID.", @@ -505,8 +514,9 @@ pub(super) async fn latest_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<Ro .services .rooms .timeline - .latest_pdu_in_room(&room_id)? - .ok_or_else(|| Error::bad_database("Failed to find the latest PDU in database"))?; + .latest_pdu_in_room(&room_id) + .await + .map_err(|_| Error::bad_database("Failed to find the latest PDU in database"))?; Ok(RoomMessageEventContent::text_plain(format!("{latest_pdu:?}"))) } @@ -520,7 +530,8 @@ pub(super) async fn force_set_room_state_from_server( .services .rooms .state_cache - .server_in_room(&self.services.globals.config.server_name, &room_id)? + .server_in_room(&self.services.globals.config.server_name, &room_id) + .await { return Ok(RoomMessageEventContent::text_plain( "We are not participating in the room / we don't know about the room ID.", @@ -531,13 +542,13 @@ pub(super) async fn force_set_room_state_from_server( .services .rooms .timeline - .latest_pdu_in_room(&room_id)? - .ok_or_else(|| Error::bad_database("Failed to find the latest PDU in database"))?; + .latest_pdu_in_room(&room_id) + .await + .map_err(|_| Error::bad_database("Failed to find the latest PDU in database"))?; - let room_version = self.services.rooms.state.get_room_version(&room_id)?; + let room_version = self.services.rooms.state.get_room_version(&room_id).await?; let mut state: HashMap<u64, Arc<EventId>> = HashMap::new(); - let pub_key_map = RwLock::new(BTreeMap::new()); let remote_state_response = self .services @@ -551,30 +562,28 @@ pub(super) async fn force_set_room_state_from_server( ) .await?; - let mut events = Vec::with_capacity(remote_state_response.pdus.len()); - for pdu in remote_state_response.pdus.clone() { - events.push(match self.services.rooms.event_handler.parse_incoming_pdu(&pdu) { + match self + .services + .rooms + .event_handler + .parse_incoming_pdu(&pdu) + .await + { Ok(t) => t, Err(e) => { warn!("Could not parse PDU, ignoring: {e}"); continue; }, - }); + }; } - info!("Fetching required signing keys for all the state events we got"); - self.services - .server_keys - .fetch_required_signing_keys(events.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map) - .await?; - info!("Going through room_state response PDUs"); - for result in remote_state_response - .pdus - .iter() - .map(|pdu| validate_and_add_event_id(self.services, pdu, &room_version, &pub_key_map)) - { + for result in remote_state_response.pdus.iter().map(|pdu| { + self.services + .server_keys + .validate_and_add_event_id(pdu, &room_version) + }) { let Ok((event_id, value)) = result.await else { continue; }; @@ -587,23 +596,26 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .outlier - .add_pdu_outlier(&event_id, &value)?; + .add_pdu_outlier(&event_id, &value); + if let Some(state_key) = &pdu.state_key { let shortstatekey = self .services .rooms .short - 
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await; + state.insert(shortstatekey, pdu.event_id.clone()); } } info!("Going through auth_chain response"); - for result in remote_state_response - .auth_chain - .iter() - .map(|pdu| validate_and_add_event_id(self.services, pdu, &room_version, &pub_key_map)) - { + for result in remote_state_response.auth_chain.iter().map(|pdu| { + self.services + .server_keys + .validate_and_add_event_id(pdu, &room_version) + }) { let Ok((event_id, value)) = result.await else { continue; }; @@ -611,7 +623,7 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .outlier - .add_pdu_outlier(&event_id, &value)?; + .add_pdu_outlier(&event_id, &value); } let new_room_state = self @@ -622,17 +634,22 @@ pub(super) async fn force_set_room_state_from_server( .await?; info!("Forcing new room state"); - let (short_state_hash, new, removed) = self + let HashSetCompressStateEvent { + shortstatehash: short_state_hash, + added, + removed, + } = self .services .rooms .state_compressor - .save_state(room_id.clone().as_ref(), new_room_state)?; + .save_state(room_id.clone().as_ref(), new_room_state) + .await?; let state_lock = self.services.rooms.state.mutex.lock(&room_id).await; self.services .rooms .state - .force_state(room_id.clone().as_ref(), short_state_hash, new, removed, &state_lock) + .force_state(room_id.clone().as_ref(), short_state_hash, added, removed, &state_lock) .await?; info!( @@ -642,7 +659,8 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .state_cache - .update_joined_count(&room_id)?; + .update_joined_count(&room_id) + .await; drop(state_lock); @@ -653,10 +671,33 @@ pub(super) async fn force_set_room_state_from_server( #[admin_command] pub(super) async fn get_signing_keys( - &self, server_name: Option<Box<ServerName>>, _cached: bool, + &self, server_name: Option<Box<ServerName>>, notary: Option<Box<ServerName>>, query: bool, ) -> Result<RoomMessageEventContent> { let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into()); - let signing_keys = self.services.globals.signing_keys_for(&server_name)?; + + if let Some(notary) = notary { + let signing_keys = self + .services + .server_keys + .notary_request(¬ary, &server_name) + .await?; + + return Ok(RoomMessageEventContent::notice_markdown(format!( + "```rs\n{signing_keys:#?}\n```" + ))); + } + + let signing_keys = if query { + self.services + .server_keys + .server_request(&server_name) + .await? + } else { + self.services + .server_keys + .signing_keys_for(&server_name) + .await? + }; Ok(RoomMessageEventContent::notice_markdown(format!( "```rs\n{signing_keys:#?}\n```" @@ -664,34 +705,20 @@ pub(super) async fn get_signing_keys( } #[admin_command] -#[allow(dead_code)] -pub(super) async fn get_verify_keys( - &self, server_name: Option<Box<ServerName>>, cached: bool, -) -> Result<RoomMessageEventContent> { +pub(super) async fn get_verify_keys(&self, server_name: Option<Box<ServerName>>) -> Result<RoomMessageEventContent> { let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into()); - let mut out = String::new(); - - if cached { - writeln!(out, "| Key ID | VerifyKey |")?; - writeln!(out, "| --- | --- |")?; - for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name)? 
{ - writeln!(out, "| {key_id} | {verify_key:?} |")?; - } - return Ok(RoomMessageEventContent::notice_markdown(out)); - } - - let signature_ids: Vec<String> = Vec::new(); let keys = self .services .server_keys - .fetch_signing_keys_for_server(&server_name, signature_ids) - .await?; + .verify_keys_for(&server_name) + .await; + let mut out = String::new(); writeln!(out, "| Key ID | Public Key |")?; writeln!(out, "| --- | --- |")?; for (key_id, key) in keys { - writeln!(out, "| {key_id} | {key} |")?; + writeln!(out, "| {key_id} | {key:?} |")?; } Ok(RoomMessageEventContent::notice_markdown(out)) @@ -814,10 +841,10 @@ pub(super) async fn database_stats( &self, property: Option<String>, map: Option<String>, ) -> Result<RoomMessageEventContent> { let property = property.unwrap_or_else(|| "rocksdb.stats".to_owned()); - let map_name = map.as_ref().map_or(utils::string::EMPTY, String::as_str); + let map_name = map.as_ref().map_or(EMPTY, String::as_str); let mut out = String::new(); - for (name, map) in self.services.db.iter_maps() { + for (name, map) in self.services.db.iter() { if !map_name.is_empty() && *map_name != *name { continue; } diff --git a/src/admin/debug/mod.rs b/src/admin/debug/mod.rs index 20ddbf2f6bfccccfc6f0aa2bec6b687efa8abced..b74e9c36cb35fb753541d6003463f174979bc841 100644 --- a/src/admin/debug/mod.rs +++ b/src/admin/debug/mod.rs @@ -80,8 +80,16 @@ pub(super) enum DebugCommand { GetSigningKeys { server_name: Option<Box<ServerName>>, + #[arg(long)] + notary: Option<Box<ServerName>>, + #[arg(short, long)] - cached: bool, + query: bool, + }, + + /// - Get and display verify keys from local cache or remote server. + GetVerifyKeys { + server_name: Option<Box<ServerName>>, }, /// - Sends a federation request to the remote server's @@ -119,6 +127,13 @@ pub(super) enum DebugCommand { /// the command. VerifyJson, + /// - Verify PDU + /// + /// This re-verifies a PDU existing in the database found by ID. + VerifyPdu { + event_id: Box<EventId>, + }, + /// - Prints the very first PDU in the specified room (typically /// m.room.create) FirstPduInRoom { diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index 8917a46b944fd033988840d4a6258a04602136e2..0c9df43306b6846d0bb69ab713716475fbbb9b85 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -1,19 +1,20 @@ use std::fmt::Write; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId}; -use crate::{admin_command, escape_html, get_room_info}; +use crate::{admin_command, get_room_info}; #[admin_command] pub(super) async fn disable_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); Ok(RoomMessageEventContent::text_plain("Room disabled.")) } #[admin_command] pub(super) async fn enable_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - self.services.rooms.metadata.disable_room(&room_id, false)?; + self.services.rooms.metadata.disable_room(&room_id, false); Ok(RoomMessageEventContent::text_plain("Room enabled.")) } @@ -85,7 +86,7 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box<UserId>) -> Result< )); } - if !self.services.users.exists(&user_id)?
{ + if !self.services.users.exists(&user_id).await { return Ok(RoomMessageEventContent::text_plain( "Remote user does not exist in our database.", )); @@ -96,9 +97,9 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box<UserId>) -> Result< .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .map(|room_id| get_room_info(self.services, &room_id)) - .collect(); + .then(|room_id| get_room_info(self.services, room_id)) + .collect() + .await; if rooms.is_empty() { return Ok(RoomMessageEventContent::text_plain("User is not in any rooms.")); @@ -107,33 +108,15 @@ pub(super) async fn remote_user_in_rooms(&self, user_id: Box<UserId>) -> Result< rooms.sort_by_key(|r| r.1); rooms.reverse(); - let output_plain = format!( - "Rooms {user_id} shares with us ({}):\n{}", + let output = format!( + "Rooms {user_id} shares with us ({}):\n```\n{}\n```", rooms.len(), rooms .iter() - .map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}")) + .map(|(id, members, name)| format!("{id} | Members: {members} | Name: {name}")) .collect::<Vec<_>>() .join("\n") ); - let output_html = format!( - "<table><caption>Rooms {user_id} shares with us \ - ({})</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>", - rooms.len(), - rooms - .iter() - .fold(String::new(), |mut output, (id, members, name)| { - writeln!( - output, - "<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>", - id, - members, - escape_html(name) - ) - .expect("should be able to write to string buffer"); - output - }) - ); - Ok(RoomMessageEventContent::text_html(output_plain, output_html)) + Ok(RoomMessageEventContent::text_markdown(output)) } diff --git a/src/admin/media/commands.rs b/src/admin/media/commands.rs index 3c4bf2ef83557b34288fd0ec87f5bb05a504bbd1..82ac162ebefca2c7893de9d806cfd76af6f0c259 100644 --- a/src/admin/media/commands.rs +++ b/src/admin/media/commands.rs @@ -36,7 +36,7 @@ pub(super) async fn delete( let mut mxc_urls = Vec::with_capacity(4); // parsing the PDU for any MXC URLs begins here - if let Some(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id)? 
{ + if let Ok(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id).await { if let Some(content_key) = event_json.get("content") { debug!("Event ID has \"content\"."); let content_obj = content_key.as_object(); @@ -300,7 +300,7 @@ pub(super) async fn delete_all_from_server( #[admin_command] pub(super) async fn get_file_info(&self, mxc: OwnedMxcUri) -> Result<RoomMessageEventContent> { let mxc: Mxc<'_> = mxc.as_str().try_into()?; - let metadata = self.services.media.get_metadata(&mxc); + let metadata = self.services.media.get_metadata(&mxc).await; Ok(RoomMessageEventContent::notice_markdown(format!("```\n{metadata:#?}\n```"))) } diff --git a/src/admin/processor.rs b/src/admin/processor.rs index 4f60f56e933b98acb01057bcac8c836668ba7e63..3c1895ffdce58023ff9ebef533d968658ad6cad4 100644 --- a/src/admin/processor.rs +++ b/src/admin/processor.rs @@ -17,7 +17,7 @@ utils::string::{collect_stream, common_prefix}, warn, Error, Result, }; -use futures_util::future::FutureExt; +use futures::future::FutureExt; use ruma::{ events::{ relation::InReplyTo, diff --git a/src/admin/query/account_data.rs b/src/admin/query/account_data.rs index e18c298a3d71319bb98ad33161db29c9d3d2e761..ea45eb16606a7695e382bbacbd77c1a1475d1a61 100644 --- a/src/admin/query/account_data.rs +++ b/src/admin/query/account_data.rs @@ -1,9 +1,6 @@ use clap::Subcommand; use conduit::Result; -use ruma::{ - events::{room::message::RoomMessageEventContent, RoomAccountDataEventType}, - RoomId, UserId, -}; +use ruma::{events::room::message::RoomMessageEventContent, RoomId, UserId}; use crate::Command; @@ -25,7 +22,7 @@ pub(crate) enum AccountDataCommand { /// Full user ID user_id: Box<UserId>, /// Account data event type - kind: RoomAccountDataEventType, + kind: String, /// Optional room ID of the account data room_id: Option<Box<RoomId>>, }, @@ -44,7 +41,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_ let timer = tokio::time::Instant::now(); let results = services .account_data - .changes_since(room_id.as_deref(), &user_id, since)?; + .changes_since(room_id.as_deref(), &user_id, since) + .await?; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -59,7 +57,8 @@ pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_ let timer = tokio::time::Instant::now(); let results = services .account_data - .get(room_id.as_deref(), &user_id, kind)?; + .get_raw(room_id.as_deref(), &user_id, &kind) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/appservice.rs b/src/admin/query/appservice.rs index 683c228f74bea50f9c9558a68924993bc2ee863b..02e89e7a1ba9d6d19549d72c11f787d167279395 100644 --- a/src/admin/query/appservice.rs +++ b/src/admin/query/appservice.rs @@ -26,10 +26,8 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> appservice_id, } => { let timer = tokio::time::Instant::now(); - let results = services - .appservice - .db - .get_registration(appservice_id.as_ref()); + let results = services.appservice.get_registration(&appservice_id).await; + let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -38,7 +36,7 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> }, AppserviceCommand::All => { let timer = tokio::time::Instant::now(); - let results = services.appservice.all(); + let results = services.appservice.all().await; let query_time = timer.elapsed(); 
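Nearly every query subcommand touched by this change follows the same shape: start a `tokio` timer, await the query, then render the `Debug` output in a fenced block. A condensed sketch of that pattern as a hypothetical helper (not part of the actual code):

````rust
use std::{fmt::Debug, future::Future};

use conduit::Result;
use ruma::events::room::message::RoomMessageEventContent;

// Hypothetical helper capturing the repeated timer/await/format dance.
async fn timed_query<T: Debug>(query: impl Future<Output = T>) -> Result<RoomMessageEventContent> {
	let timer = tokio::time::Instant::now();
	let results = query.await;
	let query_time = timer.elapsed();

	// Render the elapsed time plus the Debug form of the results as Markdown.
	Ok(RoomMessageEventContent::notice_markdown(format!(
		"Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```"
	)))
}
````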
Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/globals.rs b/src/admin/query/globals.rs index 5f271c2c424c1763fe8ff5dcc14c9eb855471b8e..837d34e6e0c94885395211c3bb1f4f2e02e64a80 100644 --- a/src/admin/query/globals.rs +++ b/src/admin/query/globals.rs @@ -13,8 +13,6 @@ pub(crate) enum GlobalsCommand { LastCheckForUpdatesId, - LoadKeypair, - /// - This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. SigningKeysFor { @@ -29,7 +27,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - match subcommand { GlobalsCommand::DatabaseVersion => { let timer = tokio::time::Instant::now(); - let results = services.globals.db.database_version(); + let results = services.globals.db.database_version().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -47,16 +45,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - }, GlobalsCommand::LastCheckForUpdatesId => { let timer = tokio::time::Instant::now(); - let results = services.updates.last_check_for_updates_id(); - let query_time = timer.elapsed(); - - Ok(RoomMessageEventContent::notice_markdown(format!( - "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" - ))) - }, - GlobalsCommand::LoadKeypair => { - let timer = tokio::time::Instant::now(); - let results = services.globals.db.load_keypair(); + let results = services.updates.last_check_for_updates_id().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -67,7 +56,7 @@ pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) - origin, } => { let timer = tokio::time::Instant::now(); - let results = services.globals.db.verify_keys_for(&origin); + let results = services.server_keys.verify_keys_for(&origin).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/presence.rs b/src/admin/query/presence.rs index 145ecd9b1535efc641d808066c4a061f37804835..0963429e803c59d22b893d032212f2d60b90da5b 100644 --- a/src/admin/query/presence.rs +++ b/src/admin/query/presence.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, UserId}; use crate::Command; @@ -30,7 +31,7 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>) user_id, } => { let timer = tokio::time::Instant::now(); - let results = services.presence.db.get_presence(&user_id)?; + let results = services.presence.db.get_presence(&user_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -41,12 +42,16 @@ pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>) since, } => { let timer = tokio::time::Instant::now(); - let results = services.presence.db.presence_since(since); - let presence_since: Vec<(_, _, _)> = results.collect(); + let results: Vec<(_, _, _)> = services + .presence + .presence_since(since) + .map(|(user_id, count, bytes)| (user_id.to_owned(), count, bytes.to_vec())) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( - "Query completed in {query_time:?}:\n\n```rs\n{presence_since:#?}\n```" + "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, } diff --git a/src/admin/query/pusher.rs b/src/admin/query/pusher.rs index 
637c57b65553f7573e602e6fa52a3e896d95dbc2..a1bd32f99bab39edc0bc6aab1ed36fd645d484ea 100644 --- a/src/admin/query/pusher.rs +++ b/src/admin/query/pusher.rs @@ -21,7 +21,7 @@ pub(super) async fn process(subcommand: PusherCommand, context: &Command<'_>) -> user_id, } => { let timer = tokio::time::Instant::now(); - let results = services.pusher.get_pushers(&user_id)?; + let results = services.pusher.get_pushers(&user_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/room_alias.rs b/src/admin/query/room_alias.rs index 1809e26a0f64fc7cae41fa6d2f2693520cd8c1b0..382e4a784fd0ab0db57885b4357410d060aaae39 100644 --- a/src/admin/query/room_alias.rs +++ b/src/admin/query/room_alias.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; use crate::Command; @@ -31,7 +32,7 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) alias, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.alias.resolve_local_alias(&alias); + let results = services.rooms.alias.resolve_local_alias(&alias).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -42,8 +43,13 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.alias.local_aliases_for_room(&room_id); - let aliases: Vec<_> = results.collect(); + let aliases: Vec<_> = services + .rooms + .alias + .local_aliases_for_room(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -52,8 +58,13 @@ pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) }, RoomAliasCommand::AllLocalAliases => { let timer = tokio::time::Instant::now(); - let results = services.rooms.alias.all_local_aliases(); - let aliases: Vec<_> = results.collect(); + let aliases = services + .rooms + .alias + .all_local_aliases() + .map(|(room_id, alias)| (room_id.to_owned(), alias.to_owned())) + .collect::<Vec<_>>() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/room_state_cache.rs b/src/admin/query/room_state_cache.rs index 4215cf8d6934481816dcc66c080410705484248a..e32517fb1cf8f91b6a2e8b8d00d630a57e1cb573 100644 --- a/src/admin/query/room_state_cache.rs +++ b/src/admin/query/room_state_cache.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomId, ServerName, UserId}; use crate::Command; @@ -86,7 +87,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let result = services.rooms.state_cache.server_in_room(&server, &room_id); + let result = services + .rooms + .state_cache + .server_in_room(&server, &room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -97,7 +102,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services.rooms.state_cache.room_servers(&room_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .room_servers(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); 
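Note the `map(ToOwned::to_owned)` inserted before `collect()` in these hunks: the new database streams yield borrowed items, so callers convert to owned values before collecting. A self-contained sketch of the same pattern over an in-memory stream (illustrative data):

```rust
use futures::{stream, StreamExt};

async fn collect_owned_aliases() -> Vec<String> {
	// Stand-in for a database stream that yields borrowed &str items.
	let aliases = stream::iter(["#one:example.com", "#two:example.com"]);

	// Convert each borrowed item into an owned String before collecting,
	// since the borrows cannot outlive the underlying iterator.
	aliases.map(ToOwned::to_owned).collect().await
}
```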
Ok(RoomMessageEventContent::notice_markdown(format!( @@ -108,7 +119,13 @@ pub(super) async fn process( server, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services.rooms.state_cache.server_rooms(&server).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .server_rooms(&server) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -119,7 +136,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services.rooms.state_cache.room_members(&room_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .room_members(&room_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -134,7 +157,9 @@ pub(super) async fn process( .rooms .state_cache .local_users_in_room(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -149,7 +174,9 @@ pub(super) async fn process( .rooms .state_cache .active_local_users_in_room(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -160,7 +187,7 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.room_joined_count(&room_id); + let results = services.rooms.state_cache.room_joined_count(&room_id).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -171,7 +198,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.room_invited_count(&room_id); + let results = services + .rooms + .state_cache + .room_invited_count(&room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -182,11 +213,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services + let results: Vec<_> = services .rooms .state_cache .room_useroncejoined(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -197,11 +230,13 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services + let results: Vec<_> = services .rooms .state_cache .room_members_invited(&room_id) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -216,7 +251,8 @@ pub(super) async fn process( let results = services .rooms .state_cache - .get_invite_count(&room_id, &user_id); + .get_invite_count(&room_id, &user_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -231,7 +267,8 @@ pub(super) async fn process( let results = services .rooms .state_cache - .get_left_count(&room_id, &user_id); + .get_left_count(&room_id, &user_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -242,7 +279,13 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = 
services.rooms.state_cache.rooms_joined(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_joined(&user_id) + .map(ToOwned::to_owned) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -253,7 +296,12 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services.rooms.state_cache.rooms_invited(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_invited(&user_id) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -264,7 +312,12 @@ pub(super) async fn process( user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services.rooms.state_cache.rooms_left(&user_id).collect(); + let results: Vec<_> = services + .rooms + .state_cache + .rooms_left(&user_id) + .collect() + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -276,7 +329,11 @@ pub(super) async fn process( room_id, } => { let timer = tokio::time::Instant::now(); - let results = services.rooms.state_cache.invite_state(&user_id, &room_id); + let results = services + .rooms + .state_cache + .invite_state(&user_id, &room_id) + .await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/sending.rs b/src/admin/query/sending.rs index 6d54bddfd9bd510a217ae05fb369812d89b9769e..eaab1f5eea764c5794d7c1e70ceed78145c0c1c6 100644 --- a/src/admin/query/sending.rs +++ b/src/admin/query/sending.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduit::Result; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, ServerName, UserId}; use service::sending::Destination; @@ -68,7 +69,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - SendingCommand::ActiveRequests => { let timer = tokio::time::Instant::now(); let results = services.sending.db.active_requests(); - let active_requests: Result<Vec<(_, _, _)>> = results.collect(); + let active_requests = results.collect::<Vec<_>>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -133,7 +134,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - }, }; - let queued_requests = results.collect::<Result<Vec<(_, _)>>>(); + let queued_requests = results.collect::<Vec<_>>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -199,7 +200,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - }, }; - let active_requests = results.collect::<Result<Vec<(_, _)>>>(); + let active_requests = results.collect::<Vec<_>>().await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( @@ -210,7 +211,7 @@ pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) - server_name, } => { let timer = tokio::time::Instant::now(); - let results = services.sending.db.get_latest_educount(&server_name); + let results = services.sending.db.get_latest_educount(&server_name).await; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/users.rs b/src/admin/query/users.rs index fee12fbfcae692409672c58ca51240307bbdc5c2..0792e484019d2a9297f3d36bdcbbb620732c249b 100644 --- a/src/admin/query/users.rs +++ 
b/src/admin/query/users.rs @@ -1,29 +1,344 @@ use clap::Subcommand; use conduit::Result; -use ruma::events::room::message::RoomMessageEventContent; +use futures::stream::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedDeviceId, OwnedRoomId, OwnedUserId}; -use crate::Command; +use crate::{admin_command, admin_command_dispatch}; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/users.rs pub(crate) enum UsersCommand { - Iter, + CountUsers, + + IterUsers, + + PasswordHash { + user_id: OwnedUserId, + }, + + ListDevices { + user_id: OwnedUserId, + }, + + ListDevicesMetadata { + user_id: OwnedUserId, + }, + + GetDeviceMetadata { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetDevicesVersion { + user_id: OwnedUserId, + }, + + CountOneTimeKeys { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetDeviceKeys { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetUserSigningKey { + user_id: OwnedUserId, + }, + + GetMasterKey { + user_id: OwnedUserId, + }, + + GetToDeviceEvents { + user_id: OwnedUserId, + device_id: OwnedDeviceId, + }, + + GetLatestBackup { + user_id: OwnedUserId, + }, + + GetLatestBackupVersion { + user_id: OwnedUserId, + }, + + GetBackupAlgorithm { + user_id: OwnedUserId, + version: String, + }, + + GetAllBackups { + user_id: OwnedUserId, + version: String, + }, + + GetRoomBackups { + user_id: OwnedUserId, + version: String, + room_id: OwnedRoomId, + }, + + GetBackupSession { + user_id: OwnedUserId, + version: String, + room_id: OwnedRoomId, + session_id: String, + }, +} + +#[admin_command] +async fn get_backup_session( + &self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId, session_id: String, +) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_session(&user_id, &version, &room_id, &session_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_room_backups( + &self, user_id: OwnedUserId, version: String, room_id: OwnedRoomId, +) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_room(&user_id, &version, &room_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_all_backups(&self, user_id: OwnedUserId, version: String) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self.services.key_backups.get_all(&user_id, &version).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_backup_algorithm(&self, user_id: OwnedUserId, version: String) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_backup(&user_id, &version) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_latest_backup_version(&self, user_id: OwnedUserId) -> 
Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self + .services + .key_backups + .get_latest_backup_version(&user_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_latest_backup(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self.services.key_backups.get_latest_backup(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) } -/// All the getters and iterators in key_value/users.rs -pub(super) async fn process(subcommand: UsersCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { - let services = context.services; +#[admin_command] +async fn iter_users(&self) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result: Vec<OwnedUserId> = self.services.users.stream().map(Into::into).collect().await; + + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn count_users(&self) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self.services.users.count().await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn password_hash(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self.services.users.password_hash(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn list_devices(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let devices = self + .services + .users + .all_device_ids(&user_id) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await; + + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```" + ))) +} + +#[admin_command] +async fn list_devices_metadata(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let devices = self + .services + .users + .all_devices_metadata(&user_id) + .collect::<Vec<_>>() + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{devices:#?}\n```" + ))) +} + +#[admin_command] +async fn get_device_metadata(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let device = self + .services + .users + .get_device_metadata(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```" + ))) +} + +#[admin_command] +async fn get_devices_version(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> { + let timer = 
tokio::time::Instant::now(); + let device = self.services.users.get_devicelist_version(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{device:#?}\n```" + ))) +} + +#[admin_command] +async fn count_one_time_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .count_one_time_keys(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_device_keys(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_device_keys(&user_id, &device_id) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_user_signing_key(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self.services.users.get_user_signing_key(&user_id).await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} + +#[admin_command] +async fn get_master_key(&self, user_id: OwnedUserId) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_master_key(None, &user_id, &|_| true) + .await; + let query_time = timer.elapsed(); + + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) +} - match subcommand { - UsersCommand::Iter => { - let timer = tokio::time::Instant::now(); - let results = services.users.db.iter(); - let users = results.collect::<Vec<_>>(); - let query_time = timer.elapsed(); +#[admin_command] +async fn get_to_device_events( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, +) -> Result<RoomMessageEventContent> { + let timer = tokio::time::Instant::now(); + let result = self + .services + .users + .get_to_device_events(&user_id, &device_id) + .collect::<Vec<_>>() + .await; + let query_time = timer.elapsed(); - Ok(RoomMessageEventContent::notice_markdown(format!( - "Query completed in {query_time:?}:\n\n```rs\n{users:#?}\n```" - ))) - }, - } + Ok(RoomMessageEventContent::notice_markdown(format!( + "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" + ))) } diff --git a/src/admin/room/alias.rs b/src/admin/room/alias.rs index 415b8a083207effe1aecf8aeff7fc7f4c825d4db..1ccde47dc9eeab3876c8c42f6507f8d822f60315 100644 --- a/src/admin/room/alias.rs +++ b/src/admin/room/alias.rs @@ -2,7 +2,8 @@ use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; use crate::{escape_html, Command}; @@ -66,8 +67,8 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> force, room_id, .. 
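The alias subcommands below track the same `Option`-to-`Result` shift: `resolve_local_alias` now yields `Result<OwnedRoomId>`, so the old `Ok(Some(..))`/`Ok(None)`/`Err(..)` triples collapse into `Ok(..)`/`Err(..)` pairs. A compressed sketch of the new match shape (names and messages illustrative):

```rust
use conduit::Result;
use ruma::OwnedRoomId;

fn set_alias_outcome(force: bool, resolved: Result<OwnedRoomId>) -> String {
	match (force, resolved) {
		// Alias already points somewhere, but --force was given: overwrite.
		(true, Ok(id)) => format!("overwriting alias that pointed to {id}"),
		// Alias in use and no --force: refuse.
		(false, Ok(id)) => format!("refusing to overwrite in-use alias for {id}"),
		// A failed lookup now doubles as "alias not in use": free to set it.
		(_, Err(_)) => "alias not in use, setting it".to_owned(),
	}
}
```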
- } => match (force, services.rooms.alias.resolve_local_alias(&room_alias)) { - (true, Ok(Some(id))) => match services + } => match (force, services.rooms.alias.resolve_local_alias(&room_alias).await) { + (true, Ok(id)) => match services .rooms .alias .set_alias(&room_alias, &room_id, server_user) @@ -77,10 +78,10 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> ))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - (false, Ok(Some(id))) => Ok(RoomMessageEventContent::text_plain(format!( + (false, Ok(id)) => Ok(RoomMessageEventContent::text_plain(format!( "Refusing to overwrite in use alias for {id}, use -f or --force to overwrite" ))), - (_, Ok(None)) => match services + (_, Err(_)) => match services .rooms .alias .set_alias(&room_alias, &room_id, server_user) @@ -88,12 +89,11 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> Ok(()) => Ok(RoomMessageEventContent::text_plain("Successfully set alias")), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - (_, Err(err)) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), }, RoomAliasCommand::Remove { .. - } => match services.rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => match services + } => match services.rooms.alias.resolve_local_alias(&room_alias).await { + Ok(id) => match services .rooms .alias .remove_alias(&room_alias, server_user) @@ -102,15 +102,13 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> Ok(()) => Ok(RoomMessageEventContent::text_plain(format!("Removed alias from {id}"))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Failed to remove alias: {err}"))), }, - Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), + Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), }, RoomAliasCommand::Which { .. - } => match services.rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))), - Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), + } => match services.rooms.alias.resolve_local_alias(&room_alias).await { + Ok(id) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))), + Err(_) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), }, RoomAliasCommand::List { .. 
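The `List` handling in the next hunk keeps the existing fold-into-`String` formatting; only the alias source changes (an async stream collected into owned values first). For reference, the fold pattern in isolation looks roughly like this:

```rust
use std::fmt::Write;

fn bullet_list(aliases: &[String]) -> String {
	aliases.iter().fold(String::new(), |mut output, alias| {
		// Writing into a String only fails on allocation failure.
		writeln!(output, "- {alias}").expect("should be able to write to string buffer");
		output
	})
}
```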
@@ -121,67 +119,63 @@ pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> room_id, } => { if let Some(room_id) = room_id { - let aliases = services + let aliases: Vec<OwnedRoomAliasId> = services .rooms .alias .local_aliases_for_room(&room_id) - .collect::<Result<Vec<_>, _>>(); - match aliases { - Ok(aliases) => { - let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { - writeln!(output, "- {alias}").expect("should be able to write to string buffer"); - output - }); - - let html_list = aliases.iter().fold(String::new(), |mut output, alias| { - writeln!(output, "<li>{}</li>", escape_html(alias.as_ref())) - .expect("should be able to write to string buffer"); - output - }); - - let plain = format!("Aliases for {room_id}:\n{plain_list}"); - let html = format!("Aliases for {room_id}:\n<ul>{html_list}</ul>"); - Ok(RoomMessageEventContent::text_html(plain, html)) - }, - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {err}"))), - } + .map(Into::into) + .collect() + .await; + + let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { + writeln!(output, "- {alias}").expect("should be able to write to string buffer"); + output + }); + + let html_list = aliases.iter().fold(String::new(), |mut output, alias| { + writeln!(output, "<li>{}</li>", escape_html(alias.as_ref())) + .expect("should be able to write to string buffer"); + output + }); + + let plain = format!("Aliases for {room_id}:\n{plain_list}"); + let html = format!("Aliases for {room_id}:\n<ul>{html_list}</ul>"); + Ok(RoomMessageEventContent::text_html(plain, html)) } else { let aliases = services .rooms .alias .all_local_aliases() - .collect::<Result<Vec<_>, _>>(); - match aliases { - Ok(aliases) => { - let server_name = services.globals.server_name(); - let plain_list = aliases - .iter() - .fold(String::new(), |mut output, (alias, id)| { - writeln!(output, "- `{alias}` -> #{id}:{server_name}") - .expect("should be able to write to string buffer"); - output - }); - - let html_list = aliases - .iter() - .fold(String::new(), |mut output, (alias, id)| { - writeln!( - output, - "<li><code>{}</code> -> #{}:{}</li>", - escape_html(alias.as_ref()), - escape_html(id.as_ref()), - server_name - ) - .expect("should be able to write to string buffer"); - output - }); - - let plain = format!("Aliases:\n{plain_list}"); - let html = format!("Aliases:\n<ul>{html_list}</ul>"); - Ok(RoomMessageEventContent::text_html(plain, html)) - }, - Err(e) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {e}"))), - } + .map(|(room_id, localpart)| (room_id.into(), localpart.into())) + .collect::<Vec<(OwnedRoomId, String)>>() + .await; + + let server_name = services.globals.server_name(); + let plain_list = aliases + .iter() + .fold(String::new(), |mut output, (alias, id)| { + writeln!(output, "- `{alias}` -> #{id}:{server_name}") + .expect("should be able to write to string buffer"); + output + }); + + let html_list = aliases + .iter() + .fold(String::new(), |mut output, (alias, id)| { + writeln!( + output, + "<li><code>{}</code> -> #{}:{}</li>", + escape_html(alias.as_ref()), + escape_html(id), + server_name + ) + .expect("should be able to write to string buffer"); + output + }); + + let plain = format!("Aliases:\n{plain_list}"); + let html = format!("Aliases:\n<ul>{html_list}</ul>"); + Ok(RoomMessageEventContent::text_html(plain, html)) } }, } diff --git a/src/admin/room/commands.rs b/src/admin/room/commands.rs index 
2adfa7d73205fbfadaf21acd1b7189f3da73d174..35e40c8be2495d085023d47338bec837e92a9712 100644 --- a/src/admin/room/commands.rs +++ b/src/admin/room/commands.rs @@ -1,5 +1,6 @@ use conduit::Result; -use ruma::events::room::message::RoomMessageEventContent; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId}; use crate::{admin_command, get_room_info, PAGE_SIZE}; @@ -14,37 +15,16 @@ pub(super) async fn list_rooms( .rooms .metadata .iter_ids() - .filter_map(|room_id| { - room_id - .ok() - .filter(|room_id| { - if exclude_disabled - && self - .services - .rooms - .metadata - .is_disabled(room_id) - .unwrap_or(false) - { - return false; - } - - if exclude_banned - && self - .services - .rooms - .metadata - .is_banned(room_id) - .unwrap_or(false) - { - return false; - } - - true - }) - .map(|room_id| get_room_info(self.services, &room_id)) + .filter_map(|room_id| async move { + (!exclude_disabled || !self.services.rooms.metadata.is_disabled(room_id).await).then_some(room_id) }) - .collect::<Vec<_>>(); + .filter_map(|room_id| async move { + (!exclude_banned || !self.services.rooms.metadata.is_banned(room_id).await).then_some(room_id) + }) + .then(|room_id| get_room_info(self.services, room_id)) + .collect::<Vec<_>>() + .await; + rooms.sort_by_key(|r| r.1); rooms.reverse(); @@ -74,3 +54,10 @@ pub(super) async fn list_rooms( Ok(RoomMessageEventContent::notice_markdown(output_plain)) } + +#[admin_command] +pub(super) async fn exists(&self, room_id: OwnedRoomId) -> Result<RoomMessageEventContent> { + let result = self.services.rooms.metadata.exists(&room_id).await; + + Ok(RoomMessageEventContent::notice_markdown(format!("{result}"))) +} diff --git a/src/admin/room/directory.rs b/src/admin/room/directory.rs index 7bba2eb7b6a0d67b30d95c946b9ce205df4d42ee..0bdaf56d7123dc6d7523531c99cf82d8a0e27fdd 100644 --- a/src/admin/room/directory.rs +++ b/src/admin/room/directory.rs @@ -1,10 +1,9 @@ -use std::fmt::Write; - use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId}; +use futures::StreamExt; +use ruma::{events::room::message::RoomMessageEventContent, RoomId}; -use crate::{escape_html, get_room_info, Command, PAGE_SIZE}; +use crate::{get_room_info, Command, PAGE_SIZE}; #[derive(Debug, Subcommand)] pub(crate) enum RoomDirectoryCommand { @@ -31,67 +30,51 @@ pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_> match command { RoomDirectoryCommand::Publish { room_id, - } => match services.rooms.directory.set_public(&room_id) { - Ok(()) => Ok(RoomMessageEventContent::text_plain("Room published")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), + } => { + services.rooms.directory.set_public(&room_id); + Ok(RoomMessageEventContent::notice_plain("Room published")) }, RoomDirectoryCommand::Unpublish { room_id, - } => match services.rooms.directory.set_not_public(&room_id) { - Ok(()) => Ok(RoomMessageEventContent::text_plain("Room unpublished")), - Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), + } => { + services.rooms.directory.set_not_public(&room_id); + Ok(RoomMessageEventContent::notice_plain("Room unpublished")) }, RoomDirectoryCommand::List { page, } => { // TODO: i know there's a way to do this with clap, but i can't seem to find it let page = page.unwrap_or(1); - let mut rooms = services + let mut rooms: Vec<_> = services .rooms .directory .public_rooms() - 
.filter_map(Result::ok) - .map(|id: OwnedRoomId| get_room_info(services, &id)) - .collect::<Vec<_>>(); + .then(|room_id| get_room_info(services, room_id)) + .collect() + .await; + rooms.sort_by_key(|r| r.1); rooms.reverse(); - let rooms = rooms + let rooms: Vec<_> = rooms .into_iter() .skip(page.saturating_sub(1).saturating_mul(PAGE_SIZE)) .take(PAGE_SIZE) - .collect::<Vec<_>>(); + .collect(); if rooms.is_empty() { return Ok(RoomMessageEventContent::text_plain("No more rooms.")); }; - let output_plain = format!( - "Rooms:\n{}", + let output = format!( + "Rooms (page {page}):\n```\n{}\n```", rooms .iter() - .map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}")) + .map(|(id, members, name)| format!("{id} | Members: {members} | Name: {name}")) .collect::<Vec<_>>() .join("\n") ); - let output_html = format!( - "<table><caption>Room directory - page \ - {page}</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>", - rooms - .iter() - .fold(String::new(), |mut output, (id, members, name)| { - writeln!( - output, - "<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>", - escape_html(id.as_ref()), - members, - escape_html(name.as_ref()) - ) - .expect("should be able to write to string buffer"); - output - }) - ); - Ok(RoomMessageEventContent::text_html(output_plain, output_html)) + Ok(RoomMessageEventContent::text_markdown(output)) }, } } diff --git a/src/admin/room/info.rs b/src/admin/room/info.rs index d17a292477aeeef1d84a2f6c127aa286e22f0a07..13a74a9d3828752ac5ae56fd9d989cd536a57dbe 100644 --- a/src/admin/room/info.rs +++ b/src/admin/room/info.rs @@ -1,5 +1,6 @@ use clap::Subcommand; -use conduit::Result; +use conduit::{utils::ReadyExt, Result}; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, RoomId}; use crate::{admin_command, admin_command_dispatch}; @@ -32,46 +33,40 @@ async fn list_joined_members(&self, room_id: Box<RoomId>, local_only: bool) -> R .rooms .state_accessor .get_name(&room_id) - .ok() - .flatten() - .unwrap_or_else(|| room_id.to_string()); + .await + .unwrap_or_else(|_| room_id.to_string()); - let members = self + let member_info: Vec<_> = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|member| { - if local_only { - member - .ok() - .filter(|user| self.services.globals.user_is_local(user)) - } else { - member.ok() - } - }); - - let member_info = members - .into_iter() - .map(|user_id| { - ( - user_id.clone(), + .ready_filter(|user_id| { + local_only + .then(|| self.services.globals.user_is_local(user_id)) + .unwrap_or(true) + }) + .map(ToOwned::to_owned) + .filter_map(|user_id| async move { + Some(( self.services .users .displayname(&user_id) - .unwrap_or(None) - .unwrap_or_else(|| user_id.to_string()), - ) + .await + .unwrap_or_else(|_| user_id.to_string()), + user_id, + )) }) - .collect::<Vec<_>>(); + .collect() + .await; let output_plain = format!( "{} Members in Room \"{}\":\n```\n{}\n```", member_info.len(), room_name, member_info - .iter() - .map(|(mxid, displayname)| format!("{mxid} | {displayname}")) + .into_iter() + .map(|(displayname, mxid)| format!("{mxid} | {displayname}")) .collect::<Vec<_>>() .join("\n") ); @@ -81,11 +76,12 @@ async fn list_joined_members(&self, room_id: Box<RoomId>, local_only: bool) -> R #[admin_command] async fn view_room_topic(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - let Some(room_topic) = self + let Ok(room_topic) = self .services .rooms .state_accessor - .get_room_topic(&room_id)? 
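Reviewer note: the directory `List` handler now drives one async `get_room_info` lookup per room with `StreamExt::then`, then sorts and pages the collected Vec. A sketch of that pipeline under the same assumptions as above (toy `get_room_info`, hypothetical `PAGE_SIZE`):

```rust
use futures::{stream, StreamExt};

// Stand-in for the async `get_room_info` lookup: (room_id, member_count).
async fn get_room_info(room_id: &str) -> (String, u64) {
    (room_id.to_owned(), room_id.len() as u64)
}

#[tokio::main]
async fn main() {
    const PAGE_SIZE: usize = 10;
    let page: usize = 1;

    let mut rooms: Vec<_> = stream::iter(["!aa:example.org", "!b:example.org"])
        .then(get_room_info) // one async lookup per item, in order
        .collect()
        .await;

    // Sort by member count, busiest first, then take the requested page.
    rooms.sort_by_key(|r| r.1);
    rooms.reverse();
    let rooms: Vec<_> = rooms
        .into_iter()
        .skip(page.saturating_sub(1).saturating_mul(PAGE_SIZE))
        .take(PAGE_SIZE)
        .collect();

    println!("{rooms:?}");
}
```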
+ .get_room_topic(&room_id) + .await else { return Ok(RoomMessageEventContent::text_plain("Room does not have a room topic set.")); }; diff --git a/src/admin/room/mod.rs b/src/admin/room/mod.rs index 64d2af45296eeda65a95a5d6509fcca801e82e8b..8c6cbeaae6e29e46dfd9df56566549a61309f3f9 100644 --- a/src/admin/room/mod.rs +++ b/src/admin/room/mod.rs @@ -6,6 +6,7 @@ use clap::Subcommand; use conduit::Result; +use ruma::OwnedRoomId; use self::{ alias::RoomAliasCommand, directory::RoomDirectoryCommand, info::RoomInfoCommand, moderation::RoomModerationCommand, @@ -49,4 +50,9 @@ pub(super) enum RoomCommand { #[command(subcommand)] /// - Manage the room directory Directory(RoomDirectoryCommand), + + /// - Check if we know about a room + Exists { + room_id: OwnedRoomId, + }, } diff --git a/src/admin/room/moderation.rs b/src/admin/room/moderation.rs index 70d8486b4b5fd013d8d527199b75362708ab6539..cfc048bdde811187de11f363b571382fbbc02567 100644 --- a/src/admin/room/moderation.rs +++ b/src/admin/room/moderation.rs @@ -1,6 +1,11 @@ use api::client::leave_room; use clap::Subcommand; -use conduit::{debug, error, info, warn, Result}; +use conduit::{ + debug, error, info, + utils::{IterStream, ReadyExt}, + warn, Result, +}; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId}; use crate::{admin_command, admin_command_dispatch, get_room_info}; @@ -76,7 +81,7 @@ async fn ban_room( let admin_room_alias = &self.services.globals.admin_alias; - if let Some(admin_room_id) = self.services.admin.get_admin_room()? { + if let Ok(admin_room_id) = self.services.admin.get_admin_room().await { if room.to_string().eq(&admin_room_id) || room.to_string().eq(admin_room_alias) { return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); } @@ -95,7 +100,7 @@ async fn ban_room( debug!("Room specified is a room ID, banning room ID"); - self.services.rooms.metadata.ban_room(&room_id, true)?; + self.services.rooms.metadata.ban_room(&room_id, true); room_id } else if room.is_room_alias_id() { @@ -114,7 +119,13 @@ async fn ban_room( get_alias_helper to fetch room ID remotely" ); - let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? 
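Reviewer note: `view_room_topic` switches from matching an `Option` to `let Ok(..) = ... .await else`, so a missing topic and a storage error both take the same early-return branch. The shape in isolation:

```rust
// Toy stand-in for the async state-accessor lookup.
async fn get_room_topic(room_id: &str) -> Result<String, &'static str> {
    let _ = room_id;
    Err("not found")
}

#[tokio::main]
async fn main() {
    // let-else on an awaited Result: any Err diverts to the fallback message.
    let Ok(topic) = get_room_topic("!room:example.org").await else {
        println!("Room does not have a room topic set.");
        return;
    };
    println!("{topic}");
}
```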
{ + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { room_id } else { debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); @@ -138,7 +149,7 @@ async fn ban_room( } }; - self.services.rooms.metadata.ban_room(&room_id, true)?; + self.services.rooms.metadata.ban_room(&room_id, true); room_id } else { @@ -150,56 +161,40 @@ async fn ban_room( debug!("Making all users leave the room {}", &room); if force { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - self.services.globals.user_is_local(local_user) - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would - // fail auth check) - && (self.services.globals.user_is_local(local_user) - // since this is a force operation, assume user is an admin - // if somehow this fails - && self.services - .users - .is_admin(local_user) - .unwrap_or(true)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { debug!( - "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", - &local_user, &room_id + "Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \ + admins too)", ); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } } else { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - && !self.services - .users - .is_admin(local_user) - .unwrap_or(false)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { + if self.services.users.is_admin(local_user).await { + continue; + } + debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { error!( "Error attempting to make local user {} leave room {} during room banning: {}", &local_user, &room_id, e @@ -214,12 +209,14 @@ async fn ban_room( } // remove any local aliases, ignore errors - for ref local_alias in self + for local_alias in &self .services .rooms .alias .local_aliases_for_room(&room_id) - .filter_map(Result::ok) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await { _ = self .services @@ -230,10 +227,10 @@ async fn ban_room( } // unpublish from room directory, ignore errors - _ = self.services.rooms.directory.set_not_public(&room_id); + self.services.rooms.directory.set_not_public(&room_id); if disable_federation { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); return Ok(RoomMessageEventContent::text_plain( "Room banned, removed all our local users, and 
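Reviewer note: the eviction loops above build a filtered, `boxed()` member stream and drain it with `while let Some(..) = users.next().await`, which keeps the loop body free to await `leave_room` between items. A sketch using plain `futures` (the diff's `ready_filter` is a conduwuit `utils` helper; `filter` plus `future::ready` expresses the same synchronous predicate):

```rust
use futures::{future, stream, StreamExt};

// Toy stand-in for the async leave call.
async fn leave_room(user: &str) -> Result<(), &'static str> {
    println!("left {user}");
    Ok(())
}

#[tokio::main]
async fn main() {
    let mut users = stream::iter(["@a:example.org", "@b:other.org"])
        .filter(|user| future::ready(user.ends_with(":example.org")))
        .boxed();

    // Driving the stream manually lets the body await other futures.
    while let Some(user) = users.next().await {
        if let Err(e) = leave_room(user).await {
            eprintln!("failed to leave: {e}");
        }
    }
}
```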
disabled incoming federation with room.", )); @@ -268,7 +265,7 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu for &room in &rooms_s { match <&RoomOrAliasId>::try_from(room) { Ok(room_alias_or_id) => { - if let Some(admin_room_id) = self.services.admin.get_admin_room()? { + if let Ok(admin_room_id) = self.services.admin.get_admin_room().await { if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(admin_room_alias) { info!("User specified admin room in bulk ban list, ignoring"); continue; @@ -300,43 +297,48 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu if room_alias_or_id.is_room_alias_id() { match RoomAliasId::parse(room_alias_or_id) { Ok(room_alias) => { - let room_id = - if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { - room_id - } else { - debug!( - "We don't have this room alias to a room ID locally, attempting to fetch room \ - ID over federation" - ); - - match self - .services - .rooms - .alias - .resolve_alias(&room_alias, None) - .await - { - Ok((room_id, servers)) => { - debug!( - ?room_id, - ?servers, - "Got federation response fetching room ID for {room}", - ); - room_id - }, - Err(e) => { - // don't fail if force blocking - if force { - warn!("Failed to resolve room alias {room} to a room ID: {e}"); - continue; - } - - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to resolve room alias {room} to a room ID: {e}" - ))); - }, - } - }; + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { + room_id + } else { + debug!( + "We don't have this room alias to a room ID locally, attempting to fetch room ID \ + over federation" + ); + + match self + .services + .rooms + .alias + .resolve_alias(&room_alias, None) + .await + { + Ok((room_id, servers)) => { + debug!( + ?room_id, + ?servers, + "Got federation response fetching room ID for {room}", + ); + room_id + }, + Err(e) => { + // don't fail if force blocking + if force { + warn!("Failed to resolve room alias {room} to a room ID: {e}"); + continue; + } + + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to resolve room alias {room} to a room ID: {e}" + ))); + }, + } + }; room_ids.push(room_id); }, @@ -374,74 +376,52 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu } for room_id in room_ids { - if self - .services - .rooms - .metadata - .ban_room(&room_id, true) - .is_ok() - { - debug!("Banned {room_id} successfully"); - room_ban_count = room_ban_count.saturating_add(1); - } + self.services.rooms.metadata.ban_room(&room_id, true); + + debug!("Banned {room_id} successfully"); + room_ban_count = room_ban_count.saturating_add(1); debug!("Making all users leave the room {}", &room_id); if force { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote - // users who are in the admin room to the list of local - // users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - // since this is a force operation, assume user is an - // admin if somehow this fails - && self.services - .users - .is_admin(local_user) - .unwrap_or(true)) - }) - }) { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + 
while let Some(local_user) = users.next().await { debug!( - "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", - &local_user, room_id + "Attempting leave for user {local_user} in room {room_id} (forced, ignoring all errors, evicting \ + admins too)", ); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } } else { - for local_user in self + let mut users = self .services .rooms .state_cache .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == self.services.globals.server_name() - // additional wrapped check here is to avoid adding remote - // users who are in the admin room to the list of local - // users (would fail auth check) - && (local_user.server_name() - == self.services.globals.server_name() - && !self.services - .users - .is_admin(local_user) - .unwrap_or(false)) - }) - }) { - debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { + .ready_filter(|user| self.services.globals.user_is_local(user)) + .boxed(); + + while let Some(local_user) = users.next().await { + if self.services.users.is_admin(local_user).await { + continue; + } + + debug!("Attempting leave for user {local_user} in room {room_id}"); + if let Err(e) = leave_room(self.services, local_user, &room_id, None).await { error!( - "Error attempting to make local user {} leave room {} during bulk room banning: {}", - &local_user, &room_id, e + "Error attempting to make local user {local_user} leave room {room_id} during bulk room \ + banning: {e}", ); + return Ok(RoomMessageEventContent::text_plain(format!( "Error attempting to make local user {} leave room {} during room banning (room is still \ banned but not removing any more users and not banning any more rooms): {}\nIf you would \ @@ -453,26 +433,26 @@ async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Resu } // remove any local aliases, ignore errors - for ref local_alias in self - .services + self.services .rooms .alias .local_aliases_for_room(&room_id) - .filter_map(Result::ok) - { - _ = self - .services - .rooms - .alias - .remove_alias(local_alias, &self.services.globals.server_user) - .await; - } + .map(ToOwned::to_owned) + .for_each(|local_alias| async move { + self.services + .rooms + .alias + .remove_alias(&local_alias, &self.services.globals.server_user) + .await + .ok(); + }) + .await; // unpublish from room directory, ignore errors - _ = self.services.rooms.directory.set_not_public(&room_id); + self.services.rooms.directory.set_not_public(&room_id); if disable_federation { - self.services.rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true); } } @@ -503,7 +483,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> debug!("Room specified is a room ID, unbanning room ID"); - self.services.rooms.metadata.ban_room(&room_id, false)?; + self.services.rooms.metadata.ban_room(&room_id, false); room_id } else if room.is_room_alias_id() { @@ -522,7 +502,13 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> get_alias_helper to fetch room ID remotely" ); - let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? 
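Reviewer note: alias cleanup is now a `StreamExt::for_each` with an async closure whose errors are deliberately dropped via `.ok()`, matching the "ignore errors" comment in the hunk. In isolation:

```rust
use futures::{stream, StreamExt};

// Toy stand-in for the async alias removal.
async fn remove_alias(alias: &str) -> Result<(), &'static str> {
    println!("removed {alias}");
    Ok(())
}

#[tokio::main]
async fn main() {
    stream::iter(["#old:example.org", "#older:example.org"])
        .for_each(|alias| async move {
            // `.ok()` discards failures on purpose.
            remove_alias(alias).await.ok();
        })
        .await;
}
```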
{ + let room_id = if let Ok(room_id) = self + .services + .rooms + .alias + .resolve_local_alias(&room_alias) + .await + { room_id } else { debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); @@ -546,7 +532,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> } }; - self.services.rooms.metadata.ban_room(&room_id, false)?; + self.services.rooms.metadata.ban_room(&room_id, false); room_id } else { @@ -557,7 +543,7 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> }; if enable_federation { - self.services.rooms.metadata.disable_room(&room_id, false)?; + self.services.rooms.metadata.disable_room(&room_id, false); return Ok(RoomMessageEventContent::text_plain("Room unbanned.")); } @@ -569,45 +555,42 @@ async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> #[admin_command] async fn list_banned_rooms(&self, no_details: bool) -> Result<RoomMessageEventContent> { - let rooms = self + let room_ids: Vec<OwnedRoomId> = self .services .rooms .metadata .list_banned_rooms() - .collect::<Result<Vec<_>, _>>(); + .map(Into::into) + .collect() + .await; - match rooms { - Ok(room_ids) => { - if room_ids.is_empty() { - return Ok(RoomMessageEventContent::text_plain("No rooms are banned.")); - } - - let mut rooms = room_ids - .into_iter() - .map(|room_id| get_room_info(self.services, &room_id)) - .collect::<Vec<_>>(); - rooms.sort_by_key(|r| r.1); - rooms.reverse(); - - let output_plain = format!( - "Rooms Banned ({}):\n```\n{}\n```", - rooms.len(), - rooms - .iter() - .map(|(id, members, name)| if no_details { - format!("{id}") - } else { - format!("{id}\tMembers: {members}\tName: {name}") - }) - .collect::<Vec<_>>() - .join("\n") - ); - - Ok(RoomMessageEventContent::notice_markdown(output_plain)) - }, - Err(e) => { - error!("Failed to list banned rooms: {e}"); - Ok(RoomMessageEventContent::text_plain(format!("Unable to list banned rooms: {e}"))) - }, + if room_ids.is_empty() { + return Ok(RoomMessageEventContent::text_plain("No rooms are banned.")); } + + let mut rooms = room_ids + .iter() + .stream() + .then(|room_id| get_room_info(self.services, room_id)) + .collect::<Vec<_>>() + .await; + + rooms.sort_by_key(|r| r.1); + rooms.reverse(); + + let output_plain = format!( + "Rooms Banned ({}):\n```\n{}\n```", + rooms.len(), + rooms + .iter() + .map(|(id, members, name)| if no_details { + format!("{id}") + } else { + format!("{id}\tMembers: {members}\tName: {name}") + }) + .collect::<Vec<_>>() + .join("\n") + ); + + Ok(RoomMessageEventContent::notice_markdown(output_plain)) } diff --git a/src/admin/server/commands.rs b/src/admin/server/commands.rs index de6ad98add809426972bf76f1375d1cb4be1f161..94f695ceb21587282d507bd99decc1e0ed6d606f 100644 --- a/src/admin/server/commands.rs +++ b/src/admin/server/commands.rs @@ -21,7 +21,10 @@ pub(super) async fn uptime(&self) -> Result<RoomMessageEventContent> { #[admin_command] pub(super) async fn show_config(&self) -> Result<RoomMessageEventContent> { // Construct and send the response - Ok(RoomMessageEventContent::text_plain(format!("{}", self.services.globals.config))) + Ok(RoomMessageEventContent::text_markdown(format!( + "```\n{}\n```", + self.services.globals.config + ))) } #[admin_command] @@ -104,7 +107,7 @@ pub(super) async fn backup_database(&self) -> Result<RoomMessageEventContent> { .runtime() .spawn_blocking(move || match globals.db.backup() { Ok(()) => String::new(), - Err(e) => (*e).to_string(), + Err(e) => 
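Reviewer note: `list_banned_rooms` first collects the banned IDs, then turns the Vec back into a stream via `.iter().stream()` (from conduwuit's `IterStream` extension, imported at the top of this file) to run `get_room_info` per room. With plain `futures` the second step looks like this:

```rust
use futures::{stream, StreamExt};

// Toy stand-in for the async per-room lookup.
async fn get_room_info(room_id: &String) -> (String, u64) {
    (room_id.clone(), 0)
}

#[tokio::main]
async fn main() {
    let room_ids: Vec<String> = vec!["!banned:example.org".to_owned()];

    // `stream::iter` over the borrowed Vec is the plain-futures equivalent
    // of the diff's `.iter().stream()`.
    let rooms: Vec<(String, u64)> = stream::iter(room_ids.iter())
        .then(get_room_info)
        .collect()
        .await;

    println!("{rooms:?}");
}
```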
e.to_string(), }) .await?; diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 20691f1a2644a69b470b4e7451dd9c01c9129f1a..444a7f372bfcdc23c18fac1de0ae0125c3a511ca 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -1,7 +1,13 @@ use std::{collections::BTreeMap, fmt::Write as _}; use api::client::{full_user_deactivate, join_room_by_id_helper, leave_room}; -use conduit::{error, info, utils, warn, PduBuilder, Result}; +use conduit::{ + debug_warn, error, info, is_equal_to, + utils::{self, ReadyExt}, + warn, PduBuilder, Result, +}; +use conduit_api::client::{leave_all_rooms, update_avatar_url, update_displayname}; +use futures::StreamExt; use ruma::{ events::{ room::{ @@ -10,11 +16,10 @@ redaction::RoomRedactionEventContent, }, tag::{TagEvent, TagEventContent, TagInfo}, - RoomAccountDataEventType, StateEventType, TimelineEventType, + RoomAccountDataEventType, StateEventType, }, - EventId, OwnedRoomId, OwnedRoomOrAliasId, OwnedUserId, RoomId, + EventId, OwnedRoomId, OwnedRoomOrAliasId, OwnedUserId, RoomId, UserId, }; -use serde_json::value::to_raw_value; use crate::{ admin_command, get_room_info, @@ -22,19 +27,23 @@ }; const AUTO_GEN_PASSWORD_LENGTH: usize = 25; +const BULK_JOIN_REASON: &str = "Bulk force joining this room as initiated by the server admin."; #[admin_command] pub(super) async fn list_users(&self) -> Result<RoomMessageEventContent> { - match self.services.users.list_local_users() { - Ok(users) => { - let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); - plain_msg += users.join("\n").as_str(); - plain_msg += "\n```"; + let users = self + .services + .users + .list_local_users() + .map(ToString::to_string) + .collect::<Vec<_>>() + .await; - Ok(RoomMessageEventContent::notice_markdown(plain_msg)) - }, - Err(e) => Ok(RoomMessageEventContent::text_plain(e.to_string())), - } + let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); + plain_msg += users.join("\n").as_str(); + plain_msg += "\n```"; + + Ok(RoomMessageEventContent::notice_markdown(plain_msg)) } #[admin_command] @@ -42,7 +51,7 @@ pub(super) async fn create_user(&self, username: String, password: Option<String // Validate user id let user_id = parse_local_user_id(self.services, &username)?; - if self.services.users.exists(&user_id)? 
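Reviewer note: `show_config` and the reworked `list_users` both emit markdown notices whose payload is wrapped in a code fence so the admin room renders it monospaced. The message-building shape, with hypothetical sample data:

```rust
fn main() {
    let users = ["@alice:example.org", "@bob:example.org"];

    // Count line first, then the names inside a fenced block.
    let mut msg = format!("Found {} local user account(s):\n```\n", users.len());
    msg += users.join("\n").as_str();
    msg += "\n```";

    println!("{msg}");
}
```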
{ + if self.services.users.exists(&user_id).await { return Ok(RoomMessageEventContent::text_plain(format!("Userid {user_id} already exists"))); } @@ -77,43 +86,51 @@ pub(super) async fn create_user(&self, username: String, password: Option<String self.services .users - .set_displayname(&user_id, Some(displayname)) - .await?; + .set_displayname(&user_id, Some(displayname)); // Initial account data - self.services.account_data.update( - None, - &user_id, - ruma::events::GlobalAccountDataEventType::PushRules - .to_string() - .into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: ruma::push::Ruleset::server_default(&user_id), - }, - }) - .expect("to json value always works"), - )?; + self.services + .account_data + .update( + None, + &user_id, + ruma::events::GlobalAccountDataEventType::PushRules + .to_string() + .into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: ruma::push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json value always works"), + ) + .await?; if !self.services.globals.config.auto_join_rooms.is_empty() { for room in &self.services.globals.config.auto_join_rooms { + let Ok(room_id) = self.services.rooms.alias.resolve(room).await else { + error!(%user_id, "Failed to resolve room alias to room ID when attempting to auto join {room}, skipping"); + continue; + }; + if !self .services .rooms .state_cache - .server_in_room(self.services.globals.server_name(), room)? + .server_in_room(self.services.globals.server_name(), &room_id) + .await { warn!("Skipping room {room} to automatically join as we have never joined before."); continue; } - if let Some(room_id_server_name) = room.server_name() { + if let Some(room_server_name) = room.server_name() { match join_room_by_id_helper( self.services, &user_id, - room, + &room_id, Some("Automatically joining this room upon registration".to_owned()), - &[room_id_server_name.to_owned(), self.services.globals.server_name().to_owned()], + &[self.services.globals.server_name().to_owned(), room_server_name.to_owned()], None, &None, ) @@ -123,6 +140,13 @@ pub(super) async fn create_user(&self, username: String, password: Option<String info!("Automatically joined room {room} for user {user_id}"); }, Err(e) => { + self.services + .admin + .send_message(RoomMessageEventContent::text_plain(format!( + "Failed to automatically join room {room} for user {user_id}: {e}" + ))) + .await + .ok(); // don't return this error so we don't fail registrations error!("Failed to automatically join room {room} for user {user_id}: {e}"); }, @@ -135,13 +159,14 @@ pub(super) async fn create_user(&self, username: String, password: Option<String // if this account creation is from the CLI / --execute, invite the first user // to admin room - if let Some(admin_room) = self.services.admin.get_admin_room()? { + if let Ok(admin_room) = self.services.admin.get_admin_room().await { if self .services .rooms .state_cache - .room_joined_count(&admin_room)? 
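Reviewer note: auto-join now resolves each configured alias up front and `continue`s on failure, or on rooms the server has never joined, rather than failing the whole registration. A sketch with hypothetical `resolve_alias` / `server_in_room` stand-ins:

```rust
async fn resolve_alias(alias: &str) -> Result<String, &'static str> {
    Ok(format!("!resolved-{}", alias.trim_start_matches('#')))
}

async fn server_in_room(room_id: &str) -> bool {
    !room_id.is_empty()
}

#[tokio::main]
async fn main() {
    for room in ["#welcome:example.org"] {
        // A failed resolution skips this room instead of erroring out.
        let Ok(room_id) = resolve_alias(room).await else {
            eprintln!("Failed to resolve {room}, skipping");
            continue;
        };

        if !server_in_room(&room_id).await {
            eprintln!("Skipping {room}: we have never joined it before");
            continue;
        }

        println!("auto-joining {room_id}");
    }
}
```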
- == Some(1) + .room_joined_count(&admin_room) + .await + .is_ok_and(is_equal_to!(1)) { self.services.admin.make_user_admin(&user_id).await?; @@ -167,7 +192,7 @@ pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) -> )); } - self.services.users.deactivate_account(&user_id)?; + self.services.users.deactivate_account(&user_id).await?; if !no_leave_rooms { self.services @@ -175,17 +200,22 @@ pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) -> .send_message(RoomMessageEventContent::text_plain(format!( "Making {user_id} leave all rooms after deactivation..." ))) - .await; + .await + .ok(); let all_joined_rooms: Vec<OwnedRoomId> = self .services .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(self.services, &user_id, all_joined_rooms).await?; + full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?; + update_displayname(self.services, &user_id, None, &all_joined_rooms).await?; + update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms).await?; + leave_all_rooms(self.services, &user_id).await; } Ok(RoomMessageEventContent::text_plain(format!( @@ -238,15 +268,16 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> let mut admins = Vec::new(); for username in usernames { - match parse_active_local_user_id(self.services, username) { + match parse_active_local_user_id(self.services, username).await { Ok(user_id) => { - if self.services.users.is_admin(&user_id)? && !force { + if self.services.users.is_admin(&user_id).await && !force { self.services .admin .send_message(RoomMessageEventContent::text_plain(format!( "{username} is an admin and --force is not set, skipping over" ))) - .await; + .await + .ok(); admins.push(username); continue; } @@ -258,7 +289,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .send_message(RoomMessageEventContent::text_plain(format!( "{username} is the server service account, skipping over" ))) - .await; + .await + .ok(); continue; } @@ -270,7 +302,8 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .send_message(RoomMessageEventContent::text_plain(format!( "{username} is not a valid username, skipping over: {e}" ))) - .await; + .await + .ok(); continue; }, } @@ -279,7 +312,7 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> let mut deactivation_count: usize = 0; for user_id in user_ids { - match self.services.users.deactivate_account(&user_id) { + match self.services.users.deactivate_account(&user_id).await { Ok(()) => { deactivation_count = deactivation_count.saturating_add(1); if !no_leave_rooms { @@ -289,16 +322,26 @@ pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .collect(); - full_user_deactivate(self.services, &user_id, all_joined_rooms).await?; + .map(Into::into) + .collect() + .await; + + full_user_deactivate(self.services, &user_id, &all_joined_rooms).await?; + update_displayname(self.services, &user_id, None, &all_joined_rooms) + .await + .ok(); + update_avatar_url(self.services, &user_id, None, None, &all_joined_rooms) + .await + .ok(); + leave_all_rooms(self.services, &user_id).await; } }, Err(e) => { self.services .admin .send_message(RoomMessageEventContent::text_plain(format!("Failed deactivating user: {e}"))) - .await; + .await + .ok(); }, } 
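Reviewer note: `room_joined_count` now yields a `Result<u64, _>` directly, so the "is this the first real user?" check collapses into `is_ok_and`. The diff uses conduwuit's `is_equal_to!(1)` macro; a closure is the plain-Rust equivalent:

```rust
fn main() {
    // Previously Result<Option<u64>, _>; now a single Result.
    let room_joined_count: Result<u64, &str> = Ok(1);

    // Only the server user has joined the admin room so far.
    let only_server_user_joined = room_joined_count.is_ok_and(|count| count == 1);
    assert!(only_server_user_joined);
}
```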
} @@ -326,9 +369,9 @@ pub(super) async fn list_joined_rooms(&self, user_id: String) -> Result<RoomMess .rooms .state_cache .rooms_joined(&user_id) - .filter_map(Result::ok) - .map(|room_id| get_room_info(self.services, &room_id)) - .collect(); + .then(|room_id| get_room_info(self.services, room_id)) + .collect() + .await; if rooms.is_empty() { return Ok(RoomMessageEventContent::text_plain("User is not in any rooms.")); @@ -350,18 +393,247 @@ pub(super) async fn list_joined_rooms(&self, user_id: String) -> Result<RoomMess Ok(RoomMessageEventContent::notice_markdown(output_plain)) } +#[admin_command] +pub(super) async fn force_join_list_of_local_users( + &self, room_id: OwnedRoomOrAliasId, yes_i_want_to_do_this: bool, +) -> Result<RoomMessageEventContent> { + if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + { + return Ok(RoomMessageEventContent::text_plain( + "Expected code block in command body. Add --help for details.", + )); + } + + if !yes_i_want_to_do_this { + return Ok(RoomMessageEventContent::notice_markdown( + "You must pass the --yes-i-want-to-do-this-flag to ensure you really want to force bulk join all \ + specified local users.", + )); + } + + let Ok(admin_room) = self.services.admin.get_admin_room().await else { + return Ok(RoomMessageEventContent::notice_markdown( + "There is not an admin room to check for server admins.", + )); + }; + + let (room_id, servers) = self + .services + .rooms + .alias + .resolve_with_servers(&room_id, None) + .await?; + + if !self + .services + .rooms + .state_cache + .server_in_room(self.services.globals.server_name(), &room_id) + .await + { + return Ok(RoomMessageEventContent::notice_markdown("We are not joined in this room.")); + } + + let server_admins: Vec<_> = self + .services + .rooms + .state_cache + .active_local_users_in_room(&admin_room) + .map(ToOwned::to_owned) + .collect() + .await; + + if !self + .services + .rooms + .state_cache + .room_members(&room_id) + .ready_any(|user_id| server_admins.contains(&user_id.to_owned())) + .await + { + return Ok(RoomMessageEventContent::notice_markdown( + "There is not a single server admin in the room.", + )); + } + + let usernames = self + .body + .to_vec() + .drain(1..self.body.len().saturating_sub(1)) + .collect::<Vec<_>>(); + + let mut user_ids: Vec<OwnedUserId> = Vec::with_capacity(usernames.len()); + + for username in usernames { + match parse_active_local_user_id(self.services, username).await { + Ok(user_id) => { + // don't make the server service account join + if user_id == self.services.globals.server_user { + self.services + .admin + .send_message(RoomMessageEventContent::text_plain(format!( + "{username} is the server service account, skipping over" + ))) + .await + .ok(); + continue; + } + + user_ids.push(user_id); + }, + Err(e) => { + self.services + .admin + .send_message(RoomMessageEventContent::text_plain(format!( + "{username} is not a valid username, skipping over: {e}" + ))) + .await + .ok(); + continue; + }, + } + } + + let mut failed_joins: usize = 0; + let mut successful_joins: usize = 0; + + for user_id in user_ids { + match join_room_by_id_helper( + self.services, + &user_id, + &room_id, + Some(String::from(BULK_JOIN_REASON)), + &servers, + None, + &None, + ) + .await + { + Ok(_res) => { + successful_joins = successful_joins.saturating_add(1); + }, + Err(e) => { + debug_warn!("Failed force joining {user_id} to {room_id} during bulk join: {e}"); + failed_joins = failed_joins.saturating_add(1); + }, + }; + 
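Reviewer note: the bulk-join command expects its usernames inside a fenced code block in the command body and slices them out from between the fences. A sketch of that validation and extraction, using plain slicing instead of the diff's `to_vec().drain(..)`:

```rust
fn main() {
    // The admin-command body arrives as lines; usernames sit between fences.
    let body = ["```", "@alice:example.org", "@bob:example.org", "```"];

    let has_code_block = body.len() >= 2
        && body[0].trim().starts_with("```")
        && body.last().unwrap_or(&"").trim() == "```";
    assert!(has_code_block);

    // Everything strictly between the fences is treated as a username.
    let usernames = &body[1..body.len().saturating_sub(1)];
    assert_eq!(usernames, ["@alice:example.org", "@bob:example.org"].as_slice());
}
```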
} + + Ok(RoomMessageEventContent::notice_markdown(format!( + "{successful_joins} local users have been joined to {room_id}. {failed_joins} joins failed.", + ))) +} + +#[admin_command] +pub(super) async fn force_join_all_local_users( + &self, room_id: OwnedRoomOrAliasId, yes_i_want_to_do_this: bool, +) -> Result<RoomMessageEventContent> { + if !yes_i_want_to_do_this { + return Ok(RoomMessageEventContent::notice_markdown( + "You must pass the --yes-i-want-to-do-this-flag to ensure you really want to force bulk join all local \ + users.", + )); + } + + let Ok(admin_room) = self.services.admin.get_admin_room().await else { + return Ok(RoomMessageEventContent::notice_markdown( + "There is not an admin room to check for server admins.", + )); + }; + + let (room_id, servers) = self + .services + .rooms + .alias + .resolve_with_servers(&room_id, None) + .await?; + + if !self + .services + .rooms + .state_cache + .server_in_room(self.services.globals.server_name(), &room_id) + .await + { + return Ok(RoomMessageEventContent::notice_markdown("We are not joined in this room.")); + } + + let server_admins: Vec<_> = self + .services + .rooms + .state_cache + .active_local_users_in_room(&admin_room) + .map(ToOwned::to_owned) + .collect() + .await; + + if !self + .services + .rooms + .state_cache + .room_members(&room_id) + .ready_any(|user_id| server_admins.contains(&user_id.to_owned())) + .await + { + return Ok(RoomMessageEventContent::notice_markdown( + "There is not a single server admin in the room.", + )); + } + + let mut failed_joins: usize = 0; + let mut successful_joins: usize = 0; + + for user_id in &self + .services + .users + .list_local_users() + .map(UserId::to_owned) + .collect::<Vec<_>>() + .await + { + match join_room_by_id_helper( + self.services, + user_id, + &room_id, + Some(String::from(BULK_JOIN_REASON)), + &servers, + None, + &None, + ) + .await + { + Ok(_res) => { + successful_joins = successful_joins.saturating_add(1); + }, + Err(e) => { + debug_warn!("Failed force joining {user_id} to {room_id} during bulk join: {e}"); + failed_joins = failed_joins.saturating_add(1); + }, + }; + } + + Ok(RoomMessageEventContent::notice_markdown(format!( + "{successful_joins} local users have been joined to {room_id}. {failed_joins} joins failed.", + ))) +} + #[admin_command] pub(super) async fn force_join_room( &self, user_id: String, room_id: OwnedRoomOrAliasId, ) -> Result<RoomMessageEventContent> { let user_id = parse_local_user_id(self.services, &user_id)?; - let room_id = self.services.rooms.alias.resolve(&room_id).await?; + let (room_id, servers) = self + .services + .rooms + .alias + .resolve_with_servers(&room_id, None) + .await?; assert!( self.services.globals.user_is_local(&user_id), "Parsed user_id must be a local user" ); - join_room_by_id_helper(self.services, &user_id, &room_id, None, &[], None, &None).await?; + join_room_by_id_helper(self.services, &user_id, &room_id, None, &servers, None, &None).await?; Ok(RoomMessageEventContent::notice_markdown(format!( "{user_id} has been joined to {room_id}.", @@ -404,10 +676,9 @@ pub(super) async fn force_demote( .services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? - .as_ref() - .and_then(|event| serde_json::from_str(event.content.get()).ok()?) 
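Reviewer note: both bulk-join commands tally outcomes with `saturating_add`, trading a silent clamp at `usize::MAX` for freedom from overflow panics in debug builds. In isolation:

```rust
fn main() {
    let join_results: [Result<(), &str>; 3] = [Ok(()), Err("timeout"), Ok(())];

    let mut successful_joins: usize = 0;
    let mut failed_joins: usize = 0;

    for result in join_results {
        match result {
            // saturating_add can never overflow-panic, however long the list.
            Ok(()) => successful_joins = successful_joins.saturating_add(1),
            Err(_) => failed_joins = failed_joins.saturating_add(1),
        }
    }

    println!("{successful_joins} local users have been joined. {failed_joins} joins failed.");
}
```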
- .and_then(|content: RoomPowerLevelsEventContent| content.into()); + .room_state_get_content::<RoomPowerLevelsEventContent>(&room_id, &StateEventType::RoomPowerLevels, "") + .await + .ok(); let user_can_demote_self = room_power_levels .as_ref() @@ -417,9 +688,9 @@ pub(super) async fn force_demote( .services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? - .as_ref() - .is_some_and(|event| event.sender == user_id); + .room_state_get(&room_id, &StateEventType::RoomCreate, "") + .await + .is_ok_and(|event| event.sender == user_id); if !user_can_demote_self { return Ok(RoomMessageEventContent::notice_markdown( @@ -435,14 +706,7 @@ pub(super) async fn force_demote( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&power_levels_content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &power_levels_content), &user_id, &room_id, &state_lock, @@ -473,33 +737,33 @@ pub(super) async fn make_user_admin(&self, user_id: String) -> Result<RoomMessag pub(super) async fn put_room_tag( &self, user_id: String, room_id: Box<RoomId>, tag: String, ) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; - let event = self + let mut tags_event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; - - let mut tags_event = event.map_or_else( - || TagEvent { + .get_room(&room_id, &user_id, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, - }, - |e| serde_json::from_str(e.get()).expect("Bad account data in database for user {user_id}"), - ); + }); tags_event .content .tags .insert(tag.clone().into(), TagInfo::new()); - self.services.account_data.update( - Some(&room_id), - &user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + self.services + .account_data + .update( + Some(&room_id), + &user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(RoomMessageEventContent::text_plain(format!( "Successfully updated room account data for {user_id} and room {room_id} with tag {tag}" @@ -510,30 +774,30 @@ pub(super) async fn put_room_tag( pub(super) async fn delete_room_tag( &self, user_id: String, room_id: Box<RoomId>, tag: String, ) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; - let event = self + let mut tags_event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; - - let mut tags_event = event.map_or_else( - || TagEvent { + .get_room(&room_id, &user_id, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, - }, - |e| serde_json::from_str(e.get()).expect("Bad account data in database for user {user_id}"), - ); + }); tags_event.content.tags.remove(&tag.clone().into()); - self.services.account_data.update( - Some(&room_id), - &user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + self.services 
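Reviewer note: `force_demote` and the redaction hunk below lean on new `PduBuilder` constructors plus struct-update syntax instead of spelling out every field. A toy mirror of that shape; the field names and constructor signatures here are assumptions for illustration, not conduwuit's real API:

```rust
#[derive(Debug, Default)]
struct PduBuilder {
    event_type: String,
    state_key: Option<String>,
    redacts: Option<String>,
}

impl PduBuilder {
    // Fills the state-event boilerplate, like the diff's `PduBuilder::state`.
    fn state(state_key: String, event_type: &str) -> Self {
        Self {
            event_type: event_type.to_owned(),
            state_key: Some(state_key),
            ..Default::default()
        }
    }

    // Non-state analogue, like the diff's `PduBuilder::timeline`.
    fn timeline(event_type: &str) -> Self {
        Self {
            event_type: event_type.to_owned(),
            ..Default::default()
        }
    }
}

fn main() {
    let power_levels = PduBuilder::state(String::new(), "m.room.power_levels");

    // Struct-update syntax overrides just the extra field, mirroring the
    // redaction hunk's `redacts: Some(..), ..PduBuilder::timeline(..)`.
    let redaction = PduBuilder {
        redacts: Some("$target_event:example.org".to_owned()),
        ..PduBuilder::timeline("m.room.redaction")
    };

    println!("{power_levels:?}\n{redaction:?}");
}
```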
+ .account_data + .update( + Some(&room_id), + &user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(RoomMessageEventContent::text_plain(format!( "Successfully updated room account data for {user_id} and room {room_id}, deleting room tag {tag}" @@ -542,21 +806,18 @@ pub(super) async fn delete_room_tag( #[admin_command] pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(self.services, &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id).await?; - let event = self + let tags_event = self .services .account_data - .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; - - let tags_event = event.map_or_else( - || TagEvent { + .get_room(&room_id, &user_id, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, - }, - |e| serde_json::from_str(e.get()).expect("Bad account data in database for user {user_id}"), - ); + }); Ok(RoomMessageEventContent::notice_markdown(format!( "```\n{:#?}\n```", @@ -566,11 +827,12 @@ pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box<RoomId>) #[admin_command] pub(super) async fn redact_event(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> { - let Some(event) = self + let Ok(event) = self .services .rooms .timeline - .get_non_outlier_pdu(&event_id)? + .get_non_outlier_pdu(&event_id) + .await else { return Ok(RoomMessageEventContent::text_plain("Event does not exist in our database.")); }; @@ -599,16 +861,11 @@ pub(super) async fn redact_event(&self, event_id: Box<EventId>) -> Result<RoomMe .timeline .build_and_append_pdu( PduBuilder { - event_type: TimelineEventType::RoomRedaction, - content: to_raw_value(&RoomRedactionEventContent { + redacts: Some(event.event_id.clone()), + ..PduBuilder::timeline(&RoomRedactionEventContent { redacts: Some(event.event_id.clone().into()), reason: Some(reason), }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: Some(event.event_id), - timestamp: None, }, &sender_user, &room_id, diff --git a/src/admin/user/mod.rs b/src/admin/user/mod.rs index e7bb5c732914b03ec6cf7d4e477f058703f281b2..649cdfb874695f5a52d385ce0c411223276c2fce 100644 --- a/src/admin/user/mod.rs +++ b/src/admin/user/mod.rs @@ -124,4 +124,31 @@ pub(super) enum UserCommand { RedactEvent { event_id: Box<EventId>, }, + + /// - Force joins a specified list of local users to join the specified + /// room. + /// + /// Specify a codeblock of usernames. + /// + /// At least 1 server admin must be in the room to reduce abuse. + /// + /// Requires the `--yes-i-want-to-do-this` flag. + ForceJoinListOfLocalUsers { + room_id: OwnedRoomOrAliasId, + + #[arg(long)] + yes_i_want_to_do_this: bool, + }, + + /// - Force joins all local users to the specified room. + /// + /// At least 1 server admin must be in the room to reduce abuse. + /// + /// Requires the `--yes-i-want-to-do-this` flag. 
+ ForceJoinAllLocalUsers { + room_id: OwnedRoomOrAliasId, + + #[arg(long)] + yes_i_want_to_do_this: bool, + }, } diff --git a/src/admin/utils.rs b/src/admin/utils.rs index 8d3d15ae48019729b92a3ec62aba10ec685ac5e6..ba98bbeacc786ecd6b4bfbbd799db81a7c3259ff 100644 --- a/src/admin/utils.rs +++ b/src/admin/utils.rs @@ -8,23 +8,21 @@ pub(crate) fn escape_html(s: &str) -> String { .replace('>', ">") } -pub(crate) fn get_room_info(services: &Services, id: &RoomId) -> (OwnedRoomId, u64, String) { +pub(crate) async fn get_room_info(services: &Services, room_id: &RoomId) -> (OwnedRoomId, u64, String) { ( - id.into(), + room_id.into(), services .rooms .state_cache - .room_joined_count(id) - .ok() - .flatten() + .room_joined_count(room_id) + .await .unwrap_or(0), services .rooms .state_accessor - .get_name(id) - .ok() - .flatten() - .unwrap_or_else(|| id.to_string()), + .get_name(room_id) + .await + .unwrap_or_else(|_| room_id.to_string()), ) } @@ -46,14 +44,14 @@ pub(crate) fn parse_local_user_id(services: &Services, user_id: &str) -> Result< } /// Parses user ID that is an active (not guest or deactivated) local user -pub(crate) fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> { +pub(crate) async fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> { let user_id = parse_local_user_id(services, user_id)?; - if !services.users.exists(&user_id)? { + if !services.users.exists(&user_id).await { return Err!("User {user_id:?} does not exist on this server."); } - if services.users.is_deactivated(&user_id)? { + if services.users.is_deactivated(&user_id).await? { return Err!("User {user_id:?} is deactivated."); } diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 2b89c3e82ffc32c1462f152677822aa2e6091f75..a0fc09ded5fdb3b828a95d2203e0fa3036513043 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -45,7 +45,7 @@ conduit-core.workspace = true conduit-database.workspace = true conduit-service.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true hmac.workspace = true http.workspace = true http-body-util.workspace = true @@ -59,7 +59,7 @@ ruma.workspace = true serde_html_form.workspace = true serde_json.workspace = true serde.workspace = true -sha-1.workspace = true +sha1.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/src/api/client/account.rs b/src/api/client/account.rs index cee86f80a927c467554f8065bf663928a1c92960..5ed4b3127fc137055a34eb6af9b086e0441d87a2 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -2,7 +2,8 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{debug_info, error, info, utils, warn, Error, PduBuilder, Result}; +use conduit::{debug_info, error, info, is_equal_to, utils, utils::ReadyExt, warn, Error, PduBuilder, Result}; +use futures::{FutureExt, StreamExt}; use register::RegistrationKind; use ruma::{ api::client::{ @@ -20,11 +21,10 @@ message::RoomMessageEventContent, power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, }, - GlobalAccountDataEventType, StateEventType, TimelineEventType, + GlobalAccountDataEventType, StateEventType, }, push, OwnedRoomId, UserId, }; -use serde_json::value::to_raw_value; use service::Services; use super::{join_room_by_id_helper, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; @@ -48,14 +48,23 @@ pub(crate) async fn get_register_available_route( State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, 
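Reviewer note: `get_room_info` in `utils.rs` is now async: both lookups are awaited and each `Err` collapses to a default, replacing the old `Result<Option<_>>` `.ok().flatten()` chains. A runnable sketch with toy lookups:

```rust
async fn room_joined_count(_room_id: &str) -> Result<u64, &'static str> {
    Err("not cached")
}

async fn get_name(_room_id: &str) -> Result<String, &'static str> {
    Err("no m.room.name event")
}

// Each error falls back to a default: 0 members, or the room ID as the name.
async fn get_room_info(room_id: &str) -> (String, u64, String) {
    (
        room_id.to_owned(),
        room_joined_count(room_id).await.unwrap_or(0),
        get_name(room_id).await.unwrap_or_else(|_| room_id.to_owned()),
    )
}

#[tokio::main]
async fn main() {
    println!("{:?}", get_room_info("!room:example.org").await);
}
```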
body: Ruma<get_username_availability::v3::Request>, ) -> Result<get_username_availability::v3::Response> { + // workaround for https://github.com/matrix-org/matrix-appservice-irc/issues/1780 due to inactivity of fixing the issue + let is_matrix_appservice_irc = body.appservice_info.as_ref().is_some_and(|appservice| { + appservice.registration.id == "irc" + || appservice.registration.id.contains("matrix-appservice-irc") + || appservice.registration.id.contains("matrix_appservice_irc") + }); + // Validate user id let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services.globals.server_name()) .ok() - .filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id)) + .filter(|user_id| { + (!user_id.is_historical() || is_matrix_appservice_irc) && services.globals.user_is_local(user_id) + }) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; // Check if username is creative enough - if services.users.exists(&user_id)? { + if services.users.exists(&user_id).await { return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); } @@ -100,8 +109,8 @@ pub(crate) async fn register_route( if !services.globals.allow_registration() && body.appservice_info.is_none() { info!( "Registration disabled and request not from known appservice, rejecting registration attempt for username \ - {:?}", - body.username + \"{}\"", + body.username.as_deref().unwrap_or("") ); return Err(Error::BadRequest(ErrorKind::forbidden(), "Registration has been disabled.")); } @@ -110,12 +119,12 @@ pub(crate) async fn register_route( if is_guest && (!services.globals.allow_guest_registration() - || (services.globals.allow_registration() && services.globals.config.registration_token.is_some())) + || (services.globals.allow_registration() && services.globals.registration_token.is_some())) { info!( "Guest registration disabled / registration enabled with token configured, rejecting guest registration \ - attempt, initial device name: {:?}", - body.initial_device_display_name + attempt, initial device name: \"{}\"", + body.initial_device_display_name.as_deref().unwrap_or("") ); return Err(Error::BadRequest( ErrorKind::GuestAccessForbidden, @@ -125,24 +134,34 @@ pub(crate) async fn register_route( // forbid guests from registering if there is not a real admin user yet. give // generic user error. - if is_guest && services.users.count()? < 2 { + if is_guest && services.users.count().await < 2 { warn!( "Guest account attempted to register before a real admin user has been registered, rejecting \ - registration. Guest's initial device name: {:?}", - body.initial_device_display_name + registration. 
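Reviewer note: the appservice-IRC workaround keys off `Option::is_some_and`, which short-circuits to `false` when no appservice authenticated the request, so ordinary users never hit the relaxed historical-ID path. The predicate in isolation, with toy types standing in for the Ruma registration structs:

```rust
struct Registration {
    id: String,
}

struct AppserviceInfo {
    registration: Registration,
}

fn main() {
    let appservice_info = Some(AppserviceInfo {
        registration: Registration {
            id: "matrix-appservice-irc_libera".to_owned(),
        },
    });

    // None (no appservice on the request) yields false without running the closure.
    let is_matrix_appservice_irc = appservice_info.as_ref().is_some_and(|appservice| {
        appservice.registration.id == "irc"
            || appservice.registration.id.contains("matrix-appservice-irc")
            || appservice.registration.id.contains("matrix_appservice_irc")
    });
    assert!(is_matrix_appservice_irc);
}
```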
Guest's initial device name: \"{}\"", + body.initial_device_display_name.as_deref().unwrap_or("") ); return Err(Error::BadRequest(ErrorKind::forbidden(), "Registration temporarily disabled.")); } + // workaround for https://github.com/matrix-org/matrix-appservice-irc/issues/1780 due to inactivity of fixing the issue + let is_matrix_appservice_irc = body.appservice_info.as_ref().is_some_and(|appservice| { + appservice.registration.id == "irc" + || appservice.registration.id.contains("matrix-appservice-irc") + || appservice.registration.id.contains("matrix_appservice_irc") + }); + let user_id = match (&body.username, is_guest) { (Some(username), false) => { let proposed_user_id = UserId::parse_with_server_name(username.to_lowercase(), services.globals.server_name()) .ok() - .filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id)) + .filter(|user_id| { + (!user_id.is_historical() || is_matrix_appservice_irc) + && services.globals.user_is_local(user_id) + }) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; - if services.users.exists(&proposed_user_id)? { + if services.users.exists(&proposed_user_id).await { return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); } @@ -162,7 +181,7 @@ pub(crate) async fn register_route( services.globals.server_name(), ) .unwrap(); - if !services.users.exists(&proposed_user_id)? { + if !services.users.exists(&proposed_user_id).await { break proposed_user_id; } }, @@ -182,7 +201,7 @@ pub(crate) async fn register_route( // UIAA let mut uiaainfo; - let skip_auth = if services.globals.config.registration_token.is_some() { + let skip_auth = if services.globals.registration_token.is_some() { // Registration token required uiaainfo = UiaaInfo { flows: vec![AuthFlow { @@ -210,12 +229,15 @@ pub(crate) async fn register_route( if !skip_auth { if let Some(auth) = &body.auth { - let (worked, uiaainfo) = services.uiaa.try_auth( - &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), - "".into(), - auth, - &uiaainfo, - )?; + let (worked, uiaainfo) = services + .uiaa + .try_auth( + &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), + "".into(), + auth, + &uiaainfo, + ) + .await?; if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -227,7 +249,7 @@ pub(crate) async fn register_route( "".into(), &uiaainfo, &json, - )?; + ); return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -255,21 +277,23 @@ pub(crate) async fn register_route( services .users - .set_displayname(&user_id, Some(displayname.clone())) - .await?; + .set_displayname(&user_id, Some(displayname.clone())); // Initial account data - services.account_data.update( - None, - &user_id, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: push::Ruleset::server_default(&user_id), - }, - }) - .expect("to json always works"), - )?; + services + .account_data + .update( + None, + &user_id, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json always works"), + ) + .await?; // Inhibit login does not work for guests if !is_guest && 
body.inhibit_login { @@ -294,22 +318,27 @@ pub(crate) async fn register_route( let token = utils::random_string(TOKEN_LENGTH); // Create device for this account - services.users.create_device( - &user_id, - &device_id, - &token, - body.initial_device_display_name.clone(), - Some(client.to_string()), - )?; + services + .users + .create_device( + &user_id, + &device_id, + &token, + body.initial_device_display_name.clone(), + Some(client.to_string()), + ) + .await?; debug_info!(%user_id, %device_id, "User account was created"); - let device_display_name = body.initial_device_display_name.clone().unwrap_or_default(); + let device_display_name = body.initial_device_display_name.as_deref().unwrap_or(""); // log in conduit admin channel if a non-guest user registered if body.appservice_info.is_none() && !is_guest { if !device_display_name.is_empty() { - info!("New user \"{user_id}\" registered on this server with device display name: {device_display_name}"); + info!( + "New user \"{user_id}\" registered on this server with device display name: \"{device_display_name}\"" + ); if services.globals.config.admin_room_notices { services @@ -318,7 +347,8 @@ pub(crate) async fn register_route( "New user \"{user_id}\" registered on this server from IP {client} and device display name \ \"{device_display_name}\"" ))) - .await; + .await + .ok(); } } else { info!("New user \"{user_id}\" registered on this server."); @@ -329,7 +359,8 @@ pub(crate) async fn register_route( .send_message(RoomMessageEventContent::notice_plain(format!( "New user \"{user_id}\" registered on this server from IP {client}" ))) - .await; + .await + .ok(); } } } @@ -346,7 +377,8 @@ pub(crate) async fn register_route( "Guest user \"{user_id}\" with device display name \"{device_display_name}\" registered on \ this server from IP {client}" ))) - .await; + .await + .ok(); } } else { #[allow(clippy::collapsible_else_if)] @@ -357,7 +389,8 @@ pub(crate) async fn register_route( "Guest user \"{user_id}\" with no device display name registered on this server from IP \ {client}", ))) - .await; + .await + .ok(); } } } @@ -365,10 +398,15 @@ pub(crate) async fn register_route( // If this is the first real user, grant them admin privileges except for guest // users Note: the server user, @conduit:servername, is generated first if !is_guest { - if let Some(admin_room) = services.admin.get_admin_room()? { - if services.rooms.state_cache.room_joined_count(&admin_room)? == Some(1) { + if let Ok(admin_room) = services.admin.get_admin_room().await { + if services + .rooms + .state_cache + .room_joined_count(&admin_room) + .await + .is_ok_and(is_equal_to!(1)) + { services.admin.make_user_admin(&user_id).await?; - warn!("Granting {user_id} admin privileges as the first user"); } } @@ -379,25 +417,32 @@ pub(crate) async fn register_route( && (services.globals.allow_guests_auto_join_rooms() || !is_guest) { for room in &services.globals.config.auto_join_rooms { + let Ok(room_id) = services.rooms.alias.resolve(room).await else { + error!("Failed to resolve room alias to room ID when attempting to auto join {room}, skipping"); + continue; + }; + if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), room)? 
+ .server_in_room(services.globals.server_name(), &room_id) + .await { warn!("Skipping room {room} to automatically join as we have never joined before."); continue; } - if let Some(room_id_server_name) = room.server_name() { + if let Some(room_server_name) = room.server_name() { if let Err(e) = join_room_by_id_helper( &services, &user_id, - room, + &room_id, Some("Automatically joining this room upon registration".to_owned()), - &[room_id_server_name.to_owned(), services.globals.server_name().to_owned()], + &[services.globals.server_name().to_owned(), room_server_name.to_owned()], None, &body.appservice_info, ) + .boxed() .await { // don't return this error so we don't fail registrations @@ -461,16 +506,20 @@ pub(crate) async fn change_password_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } - // Success! + + // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -482,14 +531,12 @@ pub(crate) async fn change_password_route( if body.logout_devices { // Logout all devices except the current one - for id in services + services .users .all_device_ids(sender_user) - .filter_map(Result::ok) - .filter(|id| id != sender_device) - { - services.users.remove_device(sender_user, &id)?; - } + .ready_filter(|id| id != sender_device) + .for_each(|id| services.users.remove_device(sender_user, id)) + .await; } info!("User {sender_user} changed their password."); @@ -500,7 +547,8 @@ pub(crate) async fn change_password_route( .send_message(RoomMessageEventContent::notice_plain(format!( "User {sender_user} changed their password." ))) - .await; + .await + .ok(); } Ok(change_password::v3::Response {}) @@ -520,7 +568,7 @@ pub(crate) async fn whoami_route( Ok(whoami::v3::Response { user_id: sender_user.clone(), device_id, - is_guest: services.users.is_deactivated(sender_user)? && body.appservice_info.is_none(), + is_guest: services.users.is_deactivated(sender_user).await? 
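Reviewer note: `join_room_by_id_helper(..).boxed().await` type-erases the join future via `futures::FutureExt::boxed`. The patch doesn't say why; a common reason (an assumption here) is that the helper can reach itself recursively, which a bare `async fn` cannot express. The classic boxed-recursion shape:

```rust
use futures::{future::BoxFuture, FutureExt};

// Returning a BoxFuture breaks the infinitely-sized future type a
// self-referential `async fn` would otherwise produce.
fn countdown(n: u32) -> BoxFuture<'static, ()> {
    async move {
        if n > 0 {
            println!("{n}");
            countdown(n - 1).await;
        }
    }
    .boxed()
}

#[tokio::main]
async fn main() {
    countdown(3).await;
}
```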
&& body.appservice_info.is_none(), }) } @@ -561,7 +609,9 @@ pub(crate) async fn deactivate_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -570,7 +620,8 @@ pub(crate) async fn deactivate_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); @@ -581,10 +632,14 @@ pub(crate) async fn deactivate_route( .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; + + super::update_displayname(&services, sender_user, None, &all_joined_rooms).await?; + super::update_avatar_url(&services, sender_user, None, None, &all_joined_rooms).await?; - full_user_deactivate(&services, sender_user, all_joined_rooms).await?; + full_user_deactivate(&services, sender_user, &all_joined_rooms).await?; info!("User {sender_user} deactivated their account."); @@ -594,7 +649,8 @@ pub(crate) async fn deactivate_route( .send_message(RoomMessageEventContent::notice_plain(format!( "User {sender_user} deactivated their account." ))) - .await; + .await + .ok(); } Ok(deactivate::v3::Response { @@ -654,7 +710,7 @@ pub(crate) async fn request_3pid_management_token_via_msisdn_route( pub(crate) async fn check_registration_token_validity( State(services): State<crate::State>, body: Ruma<check_registration_token_validity::v1::Request>, ) -> Result<check_registration_token_validity::v1::Response> { - let Some(reg_token) = services.globals.config.registration_token.clone() else { + let Some(reg_token) = services.globals.registration_token.clone() else { return Err(Error::BadRequest( ErrorKind::forbidden(), "Server does not allow token registration.", @@ -674,34 +730,27 @@ pub(crate) async fn check_registration_token_validity( /// - Removing all profile data /// - Leaving all rooms (and forgets all of them) pub async fn full_user_deactivate( - services: &Services, user_id: &UserId, all_joined_rooms: Vec<OwnedRoomId>, + services: &Services, user_id: &UserId, all_joined_rooms: &[OwnedRoomId], ) -> Result<()> { - services.users.deactivate_account(user_id)?; + services.users.deactivate_account(user_id).await?; + super::update_displayname(services, user_id, None, all_joined_rooms).await?; + super::update_avatar_url(services, user_id, None, None, all_joined_rooms).await?; - super::update_displayname(services, user_id, None, all_joined_rooms.clone()).await?; - super::update_avatar_url(services, user_id, None, None, all_joined_rooms.clone()).await?; - - let all_profile_keys = services + services .users .all_profile_keys(user_id) - .filter_map(Result::ok); - - for (profile_key, _profile_value) in all_profile_keys { - if let Err(e) = services.users.set_profile_key(user_id, &profile_key, None) { - warn!("Failed removing {user_id} profile key {profile_key}: {e}"); - } - } + .ready_for_each(|(profile_key, _)| services.users.set_profile_key(user_id, &profile_key, None)) + .await; for room_id in all_joined_rooms { - let state_lock = services.rooms.state.mutex.lock(&room_id).await; + let state_lock = services.rooms.state.mutex.lock(room_id).await; let room_power_levels = services .rooms .state_accessor - 
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? - .as_ref() - .and_then(|event| serde_json::from_str(event.content.get()).ok()?) - .and_then(|content: RoomPowerLevelsEventContent| content.into()); + .room_state_get_content::<RoomPowerLevelsEventContent>(room_id, &StateEventType::RoomPowerLevels, "") + .await + .ok(); let user_can_demote_self = room_power_levels .as_ref() @@ -710,9 +759,9 @@ pub async fn full_user_deactivate( }) || services .rooms .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? - .as_ref() - .is_some_and(|event| event.sender == user_id); + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await + .is_ok_and(|event| event.sender == user_id); if user_can_demote_self { let mut power_levels_content = room_power_levels.unwrap_or_default(); @@ -723,16 +772,9 @@ pub async fn full_user_deactivate( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&power_levels_content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &power_levels_content), user_id, - &room_id, + room_id, &state_lock, ) .await diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index 12d6352c946e53503a0cd1803001e0e6b7ac153d..83f3291d45d9cfbc195526565373419bbe4e5fff 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -1,11 +1,9 @@ use axum::extract::State; -use conduit::{debug, Error, Result}; +use conduit::{debug, Err, Result}; +use futures::StreamExt; use rand::seq::SliceRandom; use ruma::{ - api::client::{ - alias::{create_alias, delete_alias, get_alias}, - error::ErrorKind, - }, + api::client::alias::{create_alias, delete_alias, get_alias}, OwnedServerName, RoomAliasId, RoomId, }; use service::Services; @@ -33,16 +31,17 @@ pub(crate) async fn create_alias_route( .forbidden_alias_names() .is_match(body.room_alias.alias()) { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Room alias is forbidden.")); + return Err!(Request(Forbidden("Room alias is forbidden."))); } if services .rooms .alias - .resolve_local_alias(&body.room_alias)? 
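+		// an Ok result here means the alias already exists locally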
- .is_some() + .resolve_local_alias(&body.room_alias) + .await + .is_ok() { - return Err(Error::Conflict("Alias already exists.")); + return Err!(Conflict("Alias already exists.")); } services @@ -87,39 +86,32 @@ pub(crate) async fn get_alias_route( State(services): State<crate::State>, body: Ruma<get_alias::v3::Request>, ) -> Result<get_alias::v3::Response> { let room_alias = body.body.room_alias; - let servers = None; - let Ok((room_id, pre_servers)) = services - .rooms - .alias - .resolve_alias(&room_alias, servers.as_ref()) - .await - else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")); + let Ok((room_id, servers)) = services.rooms.alias.resolve_alias(&room_alias, None).await else { + return Err!(Request(NotFound("Room with alias not found."))); }; - let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers); + let servers = room_available_servers(&services, &room_id, &room_alias, servers).await; debug!(?room_alias, ?room_id, "available servers: {servers:?}"); Ok(get_alias::v3::Response::new(room_id, servers)) } -fn room_available_servers( - services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option<Vec<OwnedServerName>>, +async fn room_available_servers( + services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: Vec<OwnedServerName>, ) -> Vec<OwnedServerName> { // find active servers in room state cache to suggest let mut servers: Vec<OwnedServerName> = services .rooms .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; // push any servers we want in the list already (e.g. responded remote alias // servers, room alias server itself) - if let Some(pre_servers) = pre_servers { - servers.extend(pre_servers.clone()); - }; + servers.extend(pre_servers); servers.sort_unstable(); servers.dedup(); diff --git a/src/api/client/backup.rs b/src/api/client/backup.rs index 4ead87776fad9569c24c74fbcdb846413facc840..f435e08699117ff412f87802aed73c70969b9571 100644 --- a/src/api/client/backup.rs +++ b/src/api/client/backup.rs @@ -1,18 +1,16 @@ use axum::extract::State; +use conduit::{err, Err}; use ruma::{ - api::client::{ - backup::{ - add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, - delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version, - get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, - get_latest_backup_info, update_backup_version, - }, - error::ErrorKind, + api::client::backup::{ + add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, + delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version, + get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, + get_latest_backup_info, update_backup_version, }, UInt, }; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `POST /_matrix/client/r0/room_keys/version` /// @@ -20,10 +18,9 @@ pub(crate) async fn create_backup_version_route( State(services): State<crate::State>, body: Ruma<create_backup_version::v3::Request>, ) -> Result<create_backup_version::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let version = services .key_backups - .create_backup(sender_user, &body.algorithm)?; + .create_backup(body.sender_user(), &body.algorithm)?; 
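+	// NB: create_backup appears to remain synchronous here, unlike the other key backup methods below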
Ok(create_backup_version::v3::Response { version, @@ -37,10 +34,10 @@ pub(crate) async fn create_backup_version_route( pub(crate) async fn update_backup_version_route( State(services): State<crate::State>, body: Ruma<update_backup_version::v3::Request>, ) -> Result<update_backup_version::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); services .key_backups - .update_backup(sender_user, &body.version, &body.algorithm)?; + .update_backup(body.sender_user(), &body.version, &body.algorithm) + .await?; Ok(update_backup_version::v3::Response {}) } @@ -51,18 +48,25 @@ pub(crate) async fn update_backup_version_route( pub(crate) async fn get_latest_backup_info_route( State(services): State<crate::State>, body: Ruma<get_latest_backup_info::v3::Request>, ) -> Result<get_latest_backup_info::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let (version, algorithm) = services .key_backups - .get_latest_backup(sender_user)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; + .get_latest_backup(body.sender_user()) + .await + .map_err(|_| err!(Request(NotFound("Key backup does not exist."))))?; Ok(get_latest_backup_info::v3::Response { algorithm, - count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version)?) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &version)?, + count: (UInt::try_from( + services + .key_backups + .count_keys(body.sender_user(), &version) + .await, + ) + .expect("user backup keys count should not be that high")), + etag: services + .key_backups + .get_etag(body.sender_user(), &version) + .await, version, }) } @@ -73,21 +77,23 @@ pub(crate) async fn get_latest_backup_info_route( pub(crate) async fn get_backup_info_route( State(services): State<crate::State>, body: Ruma<get_backup_info::v3::Request>, ) -> Result<get_backup_info::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let algorithm = services .key_backups - .get_backup(sender_user, &body.version)? 
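+		// a missing backup now surfaces as an Err rather than Ok(None)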
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; + .get_backup(body.sender_user(), &body.version) + .await + .map_err(|_| err!(Request(NotFound("Key backup does not exist at version {:?}", body.version))))?; Ok(get_backup_info::v3::Response { algorithm, - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, version: body.version.clone(), }) } @@ -101,11 +107,10 @@ pub(crate) async fn get_backup_info_route( pub(crate) async fn delete_backup_version_route( State(services): State<crate::State>, body: Ruma<delete_backup_version::v3::Request>, ) -> Result<delete_backup_version::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services .key_backups - .delete_backup(sender_user, &body.version)?; + .delete_backup(body.sender_user(), &body.version) + .await; Ok(delete_backup_version::v3::Response {}) } @@ -121,36 +126,36 @@ pub(crate) async fn delete_backup_version_route( pub(crate) async fn add_backup_keys_route( State(services): State<crate::State>, body: Ruma<add_backup_keys::v3::Request>, ) -> Result<add_backup_keys::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(body.sender_user()) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." + ))); } for (room_id, room) in &body.rooms { for (session_id, key_data) in &room.sessions { services .key_backups - .add_key(sender_user, &body.version, room_id, session_id, key_data)?; + .add_key(body.sender_user(), &body.version, room_id, session_id, key_data) + .await?; } } Ok(add_backup_keys::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, }) } @@ -165,34 +170,34 @@ pub(crate) async fn add_backup_keys_route( pub(crate) async fn add_backup_keys_for_room_route( State(services): State<crate::State>, body: Ruma<add_backup_keys_for_room::v3::Request>, ) -> Result<add_backup_keys_for_room::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? 
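+	// NB: unlike the old `Some(&body.version) != ...` comparison, is_ok_and only
+	// rejects when a latest version exists and differs from the requested one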
- .as_ref() + if services + .key_backups + .get_latest_backup_version(body.sender_user()) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." + ))); } for (session_id, key_data) in &body.sessions { services .key_backups - .add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?; + .add_key(body.sender_user(), &body.version, &body.room_id, session_id, key_data) + .await?; } Ok(add_backup_keys_for_room::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, }) } @@ -207,32 +212,38 @@ pub(crate) async fn add_backup_keys_for_room_route( pub(crate) async fn add_backup_keys_for_session_route( State(services): State<crate::State>, body: Ruma<add_backup_keys_for_session::v3::Request>, ) -> Result<add_backup_keys_for_session::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - if Some(&body.version) - != services - .key_backups - .get_latest_backup_version(sender_user)? - .as_ref() + if services + .key_backups + .get_latest_backup_version(body.sender_user()) + .await + .is_ok_and(|version| version != body.version) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); + return Err!(Request(InvalidParam( + "You may only manipulate the most recently created version of the backup." 
+ ))); } services .key_backups - .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?; + .add_key( + body.sender_user(), + &body.version, + &body.room_id, + &body.session_id, + &body.session_data, + ) + .await?; Ok(add_backup_keys_for_session::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, }) } @@ -242,9 +253,10 @@ pub(crate) async fn add_backup_keys_for_session_route( pub(crate) async fn get_backup_keys_route( State(services): State<crate::State>, body: Ruma<get_backup_keys::v3::Request>, ) -> Result<get_backup_keys::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let rooms = services.key_backups.get_all(sender_user, &body.version)?; + let rooms = services + .key_backups + .get_all(body.sender_user(), &body.version) + .await; Ok(get_backup_keys::v3::Response { rooms, @@ -257,11 +269,10 @@ pub(crate) async fn get_backup_keys_route( pub(crate) async fn get_backup_keys_for_room_route( State(services): State<crate::State>, body: Ruma<get_backup_keys_for_room::v3::Request>, ) -> Result<get_backup_keys_for_room::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sessions = services .key_backups - .get_room(sender_user, &body.version, &body.room_id)?; + .get_room(body.sender_user(), &body.version, &body.room_id) + .await; Ok(get_backup_keys_for_room::v3::Response { sessions, @@ -274,12 +285,11 @@ pub(crate) async fn get_backup_keys_for_room_route( pub(crate) async fn get_backup_keys_for_session_route( State(services): State<crate::State>, body: Ruma<get_backup_keys_for_session::v3::Request>, ) -> Result<get_backup_keys_for_session::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let key_data = services .key_backups - .get_session(sender_user, &body.version, &body.room_id, &body.session_id)? 
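+		// a database miss for this session key is mapped to a 404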
- .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."))?; + .get_session(body.sender_user(), &body.version, &body.room_id, &body.session_id) + .await + .map_err(|_| err!(Request(NotFound(debug_error!("Backup key not found for this user's session.")))))?; Ok(get_backup_keys_for_session::v3::Response { key_data, @@ -292,20 +302,21 @@ pub(crate) async fn get_backup_keys_for_session_route( pub(crate) async fn delete_backup_keys_route( State(services): State<crate::State>, body: Ruma<delete_backup_keys::v3::Request>, ) -> Result<delete_backup_keys::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services .key_backups - .delete_all_keys(sender_user, &body.version)?; + .delete_all_keys(body.sender_user(), &body.version) + .await; Ok(delete_backup_keys::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, }) } @@ -315,20 +326,21 @@ pub(crate) async fn delete_backup_keys_route( pub(crate) async fn delete_backup_keys_for_room_route( State(services): State<crate::State>, body: Ruma<delete_backup_keys_for_room::v3::Request>, ) -> Result<delete_backup_keys_for_room::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services .key_backups - .delete_room_keys(sender_user, &body.version, &body.room_id)?; + .delete_room_keys(body.sender_user(), &body.version, &body.room_id) + .await; Ok(delete_backup_keys_for_room::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, }) } @@ -338,19 +350,20 @@ pub(crate) async fn delete_backup_keys_for_room_route( pub(crate) async fn delete_backup_keys_for_session_route( State(services): State<crate::State>, body: Ruma<delete_backup_keys_for_session::v3::Request>, ) -> Result<delete_backup_keys_for_session::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services .key_backups - .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; + .delete_room_key(body.sender_user(), &body.version, &body.room_id, &body.session_id) + .await; Ok(delete_backup_keys_for_session::v3::Response { - count: (UInt::try_from( - services - .key_backups - .count_keys(sender_user, &body.version)?, - ) - .expect("user backup keys count should not be that high")), - etag: services.key_backups.get_etag(sender_user, &body.version)?, + count: services + .key_backups + .count_keys(body.sender_user(), &body.version) + .await + .try_into()?, + etag: services + .key_backups + .get_etag(body.sender_user(), &body.version) + .await, }) } diff --git a/src/api/client/capabilities.rs b/src/api/client/capabilities.rs index 83e1dc7e6a4ab0db8d6388774743f74679793227..89157e4711ff334cf26a2d8085884dee8ded36f4 100644 --- 
a/src/api/client/capabilities.rs
+++ b/src/api/client/capabilities.rs
@@ -3,7 +3,8 @@
 use axum::extract::State;
 use ruma::{
 	api::client::discovery::get_capabilities::{
-		self, Capabilities, RoomVersionStability, RoomVersionsCapability, ThirdPartyIdChangesCapability,
+		self, Capabilities, GetLoginTokenCapability, RoomVersionStability, RoomVersionsCapability,
+		ThirdPartyIdChangesCapability,
 	},
 	RoomVersionId,
 };
@@ -43,6 +44,11 @@ pub(crate) async fn get_capabilities_route(
 		enabled: false,
 	};
 
+	// we don't support generating tokens yet
+	capabilities.get_login_token = GetLoginTokenCapability {
+		enabled: false,
+	};
+
 	// MSC4133 capability
 	capabilities
 		.set("uk.tcpip.msc4133.profile_fields", json!({"enabled": true}))
diff --git a/src/api/client/config.rs b/src/api/client/config.rs
index 61cc97ff5dc914403eaf11b5348441868886b03d..3cf7113535c347611fdf71f51c7859a88e9ae6f3 100644
--- a/src/api/client/config.rs
+++ b/src/api/client/config.rs
@@ -1,4 +1,5 @@
 use axum::extract::State;
+use conduit::err;
 use ruma::{
 	api::client::{
 		config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data},
@@ -22,10 +23,11 @@ pub(crate) async fn set_global_account_data_route(
 	set_account_data(
 		&services,
 		None,
-		&body.sender_user,
+		body.sender_user.as_ref(),
 		&body.event_type.to_string(),
 		body.data.json(),
-	)?;
+	)
+	.await?;
 
 	Ok(set_global_account_data::v3::Response {})
 }
@@ -39,10 +41,11 @@ pub(crate) async fn set_room_account_data_route(
 	set_account_data(
 		&services,
 		Some(&body.room_id),
-		&body.sender_user,
+		body.sender_user.as_ref(),
 		&body.event_type.to_string(),
 		body.data.json(),
-	)?;
+	)
+	.await?;
 
 	Ok(set_room_account_data::v3::Response {})
 }
@@ -55,17 +58,14 @@ pub(crate) async fn get_global_account_data_route(
 ) -> Result<get_global_account_data::v3::Response> {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
 
-	let event: Box<RawJsonValue> = services
+	let account_data: ExtractGlobalEventContent = services
 		.account_data
-		.get(None, sender_user, body.event_type.to_string().into())?
-		.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
-
-	let account_data = serde_json::from_str::<ExtractGlobalEventContent>(event.get())
-		.map_err(|_| Error::bad_database("Invalid account data event in db."))?
-		.content;
+		.get_global(sender_user, body.event_type.clone())
+		.await
+		.map_err(|_| err!(Request(NotFound("Data not found."))))?;
 
 	Ok(get_global_account_data::v3::Response {
-		account_data,
+		account_data: account_data.content,
 	})
 }
 
@@ -77,22 +77,19 @@ pub(crate) async fn get_room_account_data_route(
 ) -> Result<get_room_account_data::v3::Response> {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
 
-	let event: Box<RawJsonValue> = services
+	let account_data: ExtractRoomEventContent = services
 		.account_data
-		.get(Some(&body.room_id), sender_user, body.event_type.clone())?
-		.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
-
-	let account_data = serde_json::from_str::<ExtractRoomEventContent>(event.get())
-		.map_err(|_| Error::bad_database("Invalid account data event in db."))?
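+		// the typed get_room accessor deserializes the account data event for us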
- .content; + .get_room(&body.room_id, sender_user, body.event_type.clone()) + .await + .map_err(|_| err!(Request(NotFound("Data not found."))))?; Ok(get_room_account_data::v3::Response { - account_data, + account_data: account_data.content, }) } -fn set_account_data( - services: &Services, room_id: Option<&RoomId>, sender_user: &Option<OwnedUserId>, event_type: &str, +async fn set_account_data( + services: &Services, room_id: Option<&RoomId>, sender_user: Option<&OwnedUserId>, event_type: &str, data: &RawJsonValue, ) -> Result<()> { let sender_user = sender_user.as_ref().expect("user is authenticated"); @@ -100,15 +97,18 @@ fn set_account_data( let data: serde_json::Value = serde_json::from_str(data.get()).map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; - services.account_data.update( - room_id, - sender_user, - event_type.into(), - &json!({ - "type": event_type, - "content": data, - }), - )?; + services + .account_data + .update( + room_id, + sender_user, + event_type.into(), + &json!({ + "type": event_type, + "content": data, + }), + ) + .await?; Ok(()) } diff --git a/src/api/client/context.rs b/src/api/client/context.rs index f223d4889150e52670ffc698a62f0f6faf7ffd82..4359ae121d29c489f2ed9ccf3b2e90f5a8bca520 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -1,15 +1,27 @@ -use std::collections::HashSet; +use std::iter::once; use axum::extract::State; +use conduit::{ + at, err, error, + utils::{future::TryExtExt, stream::ReadyExt, IterStream}, + Err, Result, +}; +use futures::{future::try_join, StreamExt, TryFutureExt}; use ruma::{ - api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, + api::client::{context::get_context, filter::LazyLoadOptions}, events::StateEventType, + UserId, +}; + +use crate::{ + client::message::{event_filter, ignored_filter, update_lazy, visibility_filter, LazySet}, + Ruma, }; -use tracing::error; -use crate::{Error, Result, Ruma}; +const LIMIT_MAX: usize = 100; +const LIMIT_DEFAULT: usize = 10; -/// # `GET /_matrix/client/r0/rooms/{roomId}/context` +/// # `GET /_matrix/client/r0/rooms/{roomId}/context/{eventId}` /// /// Allows loading room history around an event. /// @@ -18,186 +30,166 @@ pub(crate) async fn get_context_route( State(services): State<crate::State>, body: Ruma<get_context::v3::Request>, ) -> Result<get_context::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let filter = &body.filter; + let sender = body.sender(); + let (sender_user, _) = sender; + + // Use limit or else 10, with maximum 100 + let limit: usize = body + .limit + .try_into() + .unwrap_or(LIMIT_DEFAULT) + .min(LIMIT_MAX); // some clients, at least element, seem to require knowledge of redundant // members for "inline" profiles on the timeline to work properly - let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options { - LazyLoadOptions::Enabled { - include_redundant_members, - } => (true, *include_redundant_members), - LazyLoadOptions::Disabled => (false, cfg!(feature = "element_hacks")), - }; + let lazy_load_enabled = matches!(filter.lazy_load_options, LazyLoadOptions::Enabled { .. 
}); - let mut lazy_loaded = HashSet::new(); + let lazy_load_redundant = if let LazyLoadOptions::Enabled { + include_redundant_members, + } = filter.lazy_load_options + { + include_redundant_members + } else { + false + }; let base_token = services .rooms .timeline - .get_pdu_count(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?; + .get_pdu_count(&body.event_id) + .map_err(|_| err!(Request(NotFound("Event not found.")))); let base_event = services .rooms .timeline - .get_pdu(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?; + .get_pdu(&body.event_id) + .map_err(|_| err!(Request(NotFound("Base event not found.")))); - let room_id = base_event.room_id.clone(); + let (base_token, base_event) = try_join(base_token, base_event).await?; + + let room_id = &base_event.room_id; if !services .rooms .state_accessor - .user_can_see_event(sender_user, &room_id, &body.event_id)? + .user_can_see_event(sender_user, room_id, &body.event_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this event.", - )); + return Err!(Request(Forbidden("You don't have permission to view this event."))); } - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &base_event.sender, - )? || lazy_load_send_redundant - { - lazy_loaded.insert(base_event.sender.as_str().to_owned()); - } - - // Use limit or else 10, with maximum 100 - let limit = usize::try_from(body.limit).unwrap_or(10).min(100); - - let base_event = base_event.to_room_event(); - let events_before: Vec<_> = services .rooms .timeline - .pdus_until(sender_user, &room_id, base_token)? + .pdus_rev(Some(sender_user), room_id, Some(base_token)) + .await? + .ready_filter_map(|item| event_filter(item, filter)) + .filter_map(|item| ignored_filter(&services, item, sender_user)) + .filter_map(|item| visibility_filter(&services, item, sender_user)) .take(limit / 2) - .filter_map(Result::ok) // Remove buggy events - .filter(|(_, pdu)| { - services - .rooms - .state_accessor - .user_can_see_event(sender_user, &room_id, &pdu.event_id) - .unwrap_or(false) - }) - .collect(); - - for (_, event) in &events_before { - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &event.sender, - )? || lazy_load_send_redundant - { - lazy_loaded.insert(event.sender.as_str().to_owned()); - } - } - - let start_token = events_before - .last() - .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); - - let events_before: Vec<_> = events_before - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); + .collect() + .await; let events_after: Vec<_> = services .rooms .timeline - .pdus_after(sender_user, &room_id, base_token)? + .pdus(Some(sender_user), room_id, Some(base_token)) + .await? 
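+		// same event/ignore/visibility filters as the /messages handler, applied
+		// before `take` so filtered-out events don't count against the limit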
+ .ready_filter_map(|item| event_filter(item, filter)) + .filter_map(|item| ignored_filter(&services, item, sender_user)) + .filter_map(|item| visibility_filter(&services, item, sender_user)) .take(limit / 2) - .filter_map(Result::ok) // Remove buggy events - .filter(|(_, pdu)| { - services - .rooms - .state_accessor - .user_can_see_event(sender_user, &room_id, &pdu.event_id) - .unwrap_or(false) + .collect() + .await; + + let lazy = once(&(base_token, (*base_event).clone())) + .chain(events_before.iter()) + .chain(events_after.iter()) + .stream() + .fold(LazySet::new(), |lazy, item| { + update_lazy(&services, room_id, sender, lazy, item, lazy_load_redundant) }) - .collect(); - - for (_, event) in &events_after { - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &event.sender, - )? || lazy_load_send_redundant - { - lazy_loaded.insert(event.sender.as_str().to_owned()); - } - } + .await; + + let state_id = events_after + .last() + .map_or(body.event_id.as_ref(), |(_, e)| e.event_id.as_ref()); let shortstatehash = services .rooms .state_accessor - .pdu_shortstatehash( - events_after - .last() - .map_or(&*body.event_id, |(_, e)| &*e.event_id), - )? - .map_or( - services - .rooms - .state - .get_room_shortstatehash(&room_id)? - .expect("All rooms have state"), - |hash| hash, - ); + .pdu_shortstatehash(state_id) + .or_else(|_| services.rooms.state.get_room_shortstatehash(room_id)) + .await + .map_err(|e| err!(Database("State hash not found: {e}")))?; let state_ids = services .rooms .state_accessor .state_full_ids(shortstatehash) - .await?; + .await + .map_err(|e| err!(Database("State not found: {e}")))?; + + let lazy = &lazy; + let state: Vec<_> = state_ids + .iter() + .stream() + .filter_map(|(shortstatekey, event_id)| { + services + .rooms + .short + .get_statekey_from_short(*shortstatekey) + .map_ok(move |(event_type, state_key)| (event_type, state_key, event_id)) + .ok() + }) + .filter_map(|(event_type, state_key, event_id)| async move { + if lazy_load_enabled && event_type == StateEventType::RoomMember { + let user_id: &UserId = state_key.as_str().try_into().ok()?; + if !lazy.contains(user_id) { + return None; + } + } - let end_token = events_after - .last() - .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); - - let events_after: Vec<_> = events_after - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); - - let mut state = Vec::with_capacity(state_ids.len()); - - for (shortstatekey, id) in state_ids { - let (event_type, state_key) = services - .rooms - .short - .get_statekey_from_short(shortstatekey)?; - - if event_type != StateEventType::RoomMember { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; - - state.push(pdu.to_state_event()); - } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? 
else { - error!("Pdu in state not found: {}", id); - continue; - }; - - state.push(pdu.to_state_event()); - } - } + services + .rooms + .timeline + .get_pdu(event_id) + .await + .inspect_err(|_| error!("Pdu in state not found: {event_id}")) + .map(|pdu| pdu.to_state_event()) + .ok() + }) + .collect() + .await; Ok(get_context::v3::Response { - start: Some(start_token), - end: Some(end_token), - events_before, - event: Some(base_event), - events_after, + event: Some(base_event.to_room_event()), + + start: events_before + .last() + .map(at!(0)) + .as_ref() + .map(ToString::to_string), + + end: events_after + .last() + .map(at!(0)) + .as_ref() + .map(ToString::to_string), + + events_before: events_before + .into_iter() + .map(at!(1)) + .map(|pdu| pdu.to_room_event()) + .collect(), + + events_after: events_after + .into_iter() + .map(at!(1)) + .map(|pdu| pdu.to_room_event()) + .collect(), + state, }) } diff --git a/src/api/client/device.rs b/src/api/client/device.rs index bad7f28449da18f12ba9199a20c1dde5aed0a1ce..7e56f27e1bafb15c0f3cf6285dd3cb0e9808b4e1 100644 --- a/src/api/client/device.rs +++ b/src/api/client/device.rs @@ -1,8 +1,14 @@ use axum::extract::State; -use ruma::api::client::{ - device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, - error::ErrorKind, - uiaa::{AuthFlow, AuthType, UiaaInfo}, +use axum_client_ip::InsecureClientIp; +use conduit::{err, Err}; +use futures::StreamExt; +use ruma::{ + api::client::{ + device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, + error::ErrorKind, + uiaa::{AuthFlow, AuthType, UiaaInfo}, + }, + MilliSecondsSinceUnixEpoch, }; use super::SESSION_ID_LENGTH; @@ -19,8 +25,8 @@ pub(crate) async fn get_devices_route( let devices: Vec<device::Device> = services .users .all_devices_metadata(sender_user) - .filter_map(Result::ok) // Filter out buggy devices - .collect(); + .collect() + .await; Ok(get_devices::v3::Response { devices, @@ -37,8 +43,9 @@ pub(crate) async fn get_device_route( let device = services .users - .get_device_metadata(sender_user, &body.body.device_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; + .get_device_metadata(sender_user, &body.body.device_id) + .await + .map_err(|_| err!(Request(NotFound("Device not found."))))?; Ok(get_device::v3::Response { device, @@ -48,21 +55,29 @@ pub(crate) async fn get_device_route( /// # `PUT /_matrix/client/r0/devices/{deviceId}` /// /// Updates the metadata on a given device of the sender user. +#[tracing::instrument(skip_all, fields(%client), name = "update_device")] pub(crate) async fn update_device_route( - State(services): State<crate::State>, body: Ruma<update_device::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<update_device::v3::Request>, ) -> Result<update_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut device = services .users - .get_device_metadata(sender_user, &body.device_id)? 
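+		// unknown device IDs now surface as a 404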
- .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; + .get_device_metadata(sender_user, &body.device_id) + .await + .map_err(|_| err!(Request(NotFound("Device not found."))))?; device.display_name.clone_from(&body.display_name); + device.last_seen_ip.clone_from(&Some(client.to_string())); + device + .last_seen_ts + .clone_from(&Some(MilliSecondsSinceUnixEpoch::now())); services .users - .update_device_metadata(sender_user, &body.device_id, &device)?; + .update_device_metadata(sender_user, &body.device_id, &device) + .await?; Ok(update_device::v3::Response {}) } @@ -97,22 +112,28 @@ pub(crate) async fn delete_device_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { - return Err(Error::Uiaa(uiaainfo)); + return Err!(Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); + .create(sender_user, sender_device, &uiaainfo, &json); + + return Err!(Uiaa(uiaainfo)); } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); + return Err!(Request(NotJson("Not json."))); } - services.users.remove_device(sender_user, &body.device_id)?; + services + .users + .remove_device(sender_user, &body.device_id) + .await; Ok(delete_device::v3::Response {}) } @@ -149,7 +170,9 @@ pub(crate) async fn delete_devices_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -158,14 +181,15 @@ pub(crate) async fn delete_devices_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } for device_id in &body.devices { - services.users.remove_device(sender_user, device_id)?; + services.users.remove_device(sender_user, device_id).await; } Ok(delete_devices::v3::Response {}) diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index 602f876a9c960d550285b633d0114f194efee926..6120c7b399fd4b5ac9d0fc65d790f7b23c9d0897 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -1,6 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{err, info, warn, Err, Error, Result}; +use conduit::{info, warn, Err, Error, Result}; +use futures::{StreamExt, TryFutureExt}; use ruma::{ api::{ client::{ @@ -18,7 +19,7 @@ }, StateEventType, }, - uint, RoomId, ServerName, UInt, UserId, + uint, OwnedRoomId, RoomId, ServerName, UInt, UserId, }; use service::Services; @@ -36,14 +37,12 @@ pub(crate) async fn get_public_rooms_filtered_route( ) -> Result<get_public_rooms_filtered::v3::Response> { if let Some(server) = &body.server { if services - .globals - .forbidden_remote_room_directory_server_names() + .server + .config + .forbidden_remote_room_directory_server_names .contains(server) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server 
is banned on this homeserver."))); } } @@ -76,14 +75,12 @@ pub(crate) async fn get_public_rooms_route( ) -> Result<get_public_rooms::v3::Response> { if let Some(server) = &body.server { if services - .globals - .forbidden_remote_room_directory_server_names() + .server + .config + .forbidden_remote_room_directory_server_names .contains(server) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } } @@ -119,16 +116,22 @@ pub(crate) async fn set_room_visibility_route( ) -> Result<set_room_visibility::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { // Return 404 if the room doesn't exist return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } - if services.users.is_deactivated(sender_user).unwrap_or(false) && body.appservice_info.is_none() { + if services + .users + .is_deactivated(sender_user) + .await + .unwrap_or(false) + && body.appservice_info.is_none() + { return Err!(Request(Forbidden("Guests cannot publish to room directories"))); } - if !user_can_publish_room(&services, sender_user, &body.room_id)? { + if !user_can_publish_room(&services, sender_user, &body.room_id).await? { return Err(Error::BadRequest( ErrorKind::forbidden(), "User is not allowed to publish this room", @@ -138,7 +141,7 @@ pub(crate) async fn set_room_visibility_route( match &body.visibility { room::Visibility::Public => { if services.globals.config.lockdown_public_room_directory - && !services.users.is_admin(sender_user)? + && !services.users.is_admin(sender_user).await && body.appservice_info.is_none() { info!( @@ -164,7 +167,7 @@ pub(crate) async fn set_room_visibility_route( )); } - services.rooms.directory.set_public(&body.room_id)?; + services.rooms.directory.set_public(&body.room_id); if services.globals.config.admin_room_notices { services @@ -174,7 +177,7 @@ pub(crate) async fn set_room_visibility_route( } info!("{sender_user} made {0} public to the room directory", body.room_id); }, - room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id)?, + room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id), _ => { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -192,13 +195,13 @@ pub(crate) async fn set_room_visibility_route( pub(crate) async fn get_room_visibility_route( State(services): State<crate::State>, body: Ruma<get_room_visibility::v3::Request>, ) -> Result<get_room_visibility::v3::Response> { - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { // Return 404 if the room doesn't exist return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } Ok(get_room_visibility::v3::Response { - visibility: if services.rooms.directory.is_public_room(&body.room_id)? 
{ + visibility: if services.rooms.directory.is_public_room(&body.room_id).await { room::Visibility::Public } else { room::Visibility::Private @@ -257,101 +260,41 @@ pub(crate) async fn get_public_rooms_filtered_helper( } } - let mut all_rooms: Vec<_> = services + let mut all_rooms: Vec<PublicRoomsChunk> = services .rooms .directory .public_rooms() - .map(|room_id| { - let room_id = room_id?; - - let chunk = PublicRoomsChunk { - canonical_alias: services - .rooms - .state_accessor - .get_canonical_alias(&room_id)?, - name: services.rooms.state_accessor.get_name(&room_id)?, - num_joined_members: services - .rooms - .state_cache - .room_joined_count(&room_id)? - .unwrap_or_else(|| { - warn!("Room {} has no member count", room_id); - 0 - }) - .try_into() - .expect("user count should not be that big"), - topic: services - .rooms - .state_accessor - .get_room_topic(&room_id) - .unwrap_or(None), - world_readable: services.rooms.state_accessor.is_world_readable(&room_id)?, - guest_can_join: services - .rooms - .state_accessor - .guest_can_join(&room_id)?, - avatar_url: services - .rooms - .state_accessor - .get_avatar(&room_id)? - .into_option() - .unwrap_or_default() - .url, - join_rule: services - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| match c.join_rule { - JoinRule::Public => Some(PublicRoomJoinRule::Public), - JoinRule::Knock => Some(PublicRoomJoinRule::Knock), - _ => None, - }) - .map_err(|e| { - err!(Database(error!("Invalid room join rule event in database: {e}"))) - }) - }) - .transpose()? - .flatten() - .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, - room_type: services - .rooms - .state_accessor - .get_room_type(&room_id)?, - room_id, - }; - Ok(chunk) - }) - .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms - .filter(|chunk| { + .map(ToOwned::to_owned) + .then(|room_id| public_rooms_chunk(services, room_id)) + .filter_map(|chunk| async move { if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) { if let Some(name) = &chunk.name { if name.as_str().to_lowercase().contains(&query) { - return true; + return Some(chunk); } } if let Some(topic) = &chunk.topic { if topic.to_lowercase().contains(&query) { - return true; + return Some(chunk); } } if let Some(canonical_alias) = &chunk.canonical_alias { if canonical_alias.as_str().to_lowercase().contains(&query) { - return true; + return Some(chunk); } } - false - } else { - // No search term - true + return None; } + + // No search term + Some(chunk) }) // We need to collect all, so we can sort by member count - .collect(); + .collect() + .await; all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); @@ -394,22 +337,23 @@ pub(crate) async fn get_public_rooms_filtered_helper( /// Check whether the user can publish to the room directory via power levels of /// room history visibility event or room creator -fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - if let Some(event) = services +async fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result<bool> { + if let Ok(event) = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? 
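+		// when no power levels event exists, fall back to the m.room.create sender check below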
+ .room_state_get(room_id, &StateEventType::RoomPowerLevels, "") + .await { serde_json::from_str(event.content.get()) .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) .map(|content: RoomPowerLevelsEventContent| { RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility) }) - } else if let Some(event) = - services - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? + } else if let Ok(event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await { Ok(event.sender == user_id) } else { @@ -419,3 +363,62 @@ fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId )); } } + +async fn public_rooms_chunk(services: &Services, room_id: OwnedRoomId) -> PublicRoomsChunk { + PublicRoomsChunk { + canonical_alias: services + .rooms + .state_accessor + .get_canonical_alias(&room_id) + .await + .ok(), + name: services.rooms.state_accessor.get_name(&room_id).await.ok(), + num_joined_members: services + .rooms + .state_cache + .room_joined_count(&room_id) + .await + .unwrap_or(0) + .try_into() + .expect("joined count overflows ruma UInt"), + topic: services + .rooms + .state_accessor + .get_room_topic(&room_id) + .await + .ok(), + world_readable: services + .rooms + .state_accessor + .is_world_readable(&room_id) + .await, + guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await, + avatar_url: services + .rooms + .state_accessor + .get_avatar(&room_id) + .await + .into_option() + .unwrap_or_default() + .url, + join_rule: services + .rooms + .state_accessor + .room_state_get_content(&room_id, &StateEventType::RoomJoinRules, "") + .map_ok(|c: RoomJoinRulesEventContent| match c.join_rule { + JoinRule::Public => PublicRoomJoinRule::Public, + JoinRule::Knock => "knock".into(), + JoinRule::KnockRestricted(_) => "knock_restricted".into(), + _ => "invite".into(), + }) + .await + .unwrap_or_default(), + room_type: services + .rooms + .state_accessor + .get_room_type(&room_id) + .await + .ok(), + room_id, + } +} diff --git a/src/api/client/filter.rs b/src/api/client/filter.rs index 8b2690c698c525ee3d8c37e7a6f27f874fde6914..2a8ebb9c22b7ee5aa04e9e880246702faaf0389c 100644 --- a/src/api/client/filter.rs +++ b/src/api/client/filter.rs @@ -1,10 +1,8 @@ use axum::extract::State; -use ruma::api::client::{ - error::ErrorKind, - filter::{create_filter, get_filter}, -}; +use conduit::err; +use ruma::api::client::filter::{create_filter, get_filter}; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// @@ -15,11 +13,13 @@ pub(crate) async fn get_filter_route( State(services): State<crate::State>, body: Ruma<get_filter::v3::Request>, ) -> Result<get_filter::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let Some(filter) = services.users.get_filter(sender_user, &body.filter_id)? 
else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")); - }; - Ok(get_filter::v3::Response::new(filter)) + services + .users + .get_filter(sender_user, &body.filter_id) + .await + .map(get_filter::v3::Response::new) + .map_err(|_| err!(Request(NotFound("Filter not found.")))) } /// # `PUT /_matrix/client/r0/user/{userId}/filter` @@ -29,7 +29,8 @@ pub(crate) async fn create_filter_route( State(services): State<crate::State>, body: Ruma<create_filter::v3::Request>, ) -> Result<create_filter::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - Ok(create_filter::v3::Response::new( - services.users.create_filter(sender_user, &body.filter)?, - )) + + let filter_id = services.users.create_filter(sender_user, &body.filter); + + Ok(create_filter::v3::Response::new(filter_id)) } diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index a426364a29b3b168b471e7d1c55af1e99241ecb1..53ec12f92e76ed3611615e4ff6dc0e8d1ef53799 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -4,8 +4,8 @@ }; use axum::extract::State; -use conduit::{utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use conduit::{err, utils, utils::math::continue_exponential_backoff_secs, Err, Error, Result}; +use futures::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ client::{ @@ -16,12 +16,15 @@ federation, }, serde::Raw, - DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, + OneTimeKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, }; use serde_json::json; use super::SESSION_ID_LENGTH; -use crate::{service::Services, Ruma}; +use crate::{ + service::{users::parse_master_key, Services}, + Ruma, +}; /// # `POST /_matrix/client/r0/keys/upload` /// @@ -33,13 +36,13 @@ pub(crate) async fn upload_keys_route( State(services): State<crate::State>, body: Ruma<upload_keys::v3::Request>, ) -> Result<upload_keys::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let (sender_user, sender_device) = body.sender(); - for (key_key, key_value) in &body.one_time_keys { + for (key_id, one_time_key) in &body.one_time_keys { services .users - .add_one_time_key(sender_user, sender_device, key_key, key_value)?; + .add_one_time_key(sender_user, sender_device, key_id, one_time_key) + .await?; } if let Some(device_keys) = &body.device_keys { @@ -47,19 +50,22 @@ pub(crate) async fn upload_keys_route( // This check is needed to assure that signatures are kept if services .users - .get_device_keys(sender_user, sender_device)? 
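+			// an Err here means no device keys are stored for this device yet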
- .is_none() + .get_device_keys(sender_user, sender_device) + .await + .is_err() { services .users - .add_device_keys(sender_user, sender_device, device_keys)?; + .add_device_keys(sender_user, sender_device, device_keys) + .await; } } Ok(upload_keys::v3::Response { one_time_key_counts: services .users - .count_one_time_keys(sender_user, sender_device)?, + .count_one_time_keys(sender_user, sender_device) + .await, }) } @@ -120,7 +126,9 @@ pub(crate) async fn upload_signing_keys_route( if let Some(auth) = &body.auth { let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; + .try_auth(sender_user, sender_device, auth, &uiaainfo) + .await?; + if !worked { return Err(Error::Uiaa(uiaainfo)); } @@ -129,20 +137,24 @@ pub(crate) async fn upload_signing_keys_route( uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; + .create(sender_user, sender_device, &uiaainfo, &json); + return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } if let Some(master_key) = &body.master_key { - services.users.add_cross_signing_keys( - sender_user, - master_key, - &body.self_signing_key, - &body.user_signing_key, - true, // notify so that other users see the new keys - )?; + services + .users + .add_cross_signing_keys( + sender_user, + master_key, + &body.self_signing_key, + &body.user_signing_key, + true, // notify so that other users see the new keys + ) + .await?; } Ok(upload_signing_keys::v3::Response {}) @@ -179,9 +191,11 @@ pub(crate) async fn upload_signatures_route( .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))? .to_owned(), ); + services .users - .sign_key(user_id, key_id, signature, sender_user)?; + .sign_key(user_id, key_id, signature, sender_user) + .await?; } } } @@ -204,56 +218,52 @@ pub(crate) async fn get_key_changes_route( let mut device_list_updates = HashSet::new(); + let from = body + .from + .parse() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?; + + let to = body + .to + .parse() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?; + device_list_updates.extend( services .users - .keys_changed( - sender_user.as_str(), - body.from - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, - Some( - body.to - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, - ), - ) - .filter_map(Result::ok), + .keys_changed(sender_user, from, Some(to)) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await, ); - for room_id in services - .rooms - .state_cache - .rooms_joined(sender_user) - .filter_map(Result::ok) - { + let mut rooms_joined = services.rooms.state_cache.rooms_joined(sender_user).boxed(); + + while let Some(room_id) = rooms_joined.next().await { device_list_updates.extend( services .users - .keys_changed( - room_id.as_ref(), - body.from - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, - Some( - body.to - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, - ), - ) - .filter_map(Result::ok), + .room_keys_changed(room_id, from, Some(to)) + .map(|(user_id, _)| user_id) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await, ); } + Ok(get_key_changes::v3::Response { changed: device_list_updates.into_iter().collect(), left: Vec::new(), // TODO }) } -pub(crate) async fn get_keys_helper<F: 
Fn(&UserId) -> bool + Send>( +pub(crate) async fn get_keys_helper<F>( services: &Services, sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, allowed_signatures: F, include_display_names: bool, -) -> Result<get_keys::v3::Response> { +) -> Result<get_keys::v3::Response> +where + F: Fn(&UserId) -> bool + Send + Sync, +{ let mut master_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new(); let mut user_signing_keys = BTreeMap::new(); @@ -274,56 +284,60 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( if device_ids.is_empty() { let mut container = BTreeMap::new(); - for device_id in services.users.all_device_ids(user_id) { - let device_id = device_id?; - if let Some(mut keys) = services.users.get_device_keys(user_id, &device_id)? { + let mut devices = services.users.all_device_ids(user_id).boxed(); + + while let Some(device_id) = devices.next().await { + if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await { let metadata = services .users - .get_device_metadata(user_id, &device_id)? - .ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?; + .get_device_metadata(user_id, device_id) + .await + .map_err(|_| err!(Database("all_device_keys contained nonexistent device.")))?; add_unsigned_device_display_name(&mut keys, metadata, include_display_names) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + .map_err(|_| err!(Database("invalid device keys in database")))?; - container.insert(device_id, keys); + container.insert(device_id.to_owned(), keys); } } + device_keys.insert(user_id.to_owned(), container); } else { for device_id in device_ids { let mut container = BTreeMap::new(); - if let Some(mut keys) = services.users.get_device_keys(user_id, device_id)? { + if let Ok(mut keys) = services.users.get_device_keys(user_id, device_id).await { let metadata = services .users - .get_device_metadata(user_id, device_id)? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to get keys for nonexistent device.", - ))?; + .get_device_metadata(user_id, device_id) + .await + .map_err(|_| err!(Request(InvalidParam("Tried to get keys for nonexistent device."))))?; add_unsigned_device_display_name(&mut keys, metadata, include_display_names) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + .map_err(|_| err!(Database("invalid device keys in database")))?; + container.insert(device_id.to_owned(), keys); } + device_keys.insert(user_id.to_owned(), container); } } - if let Some(master_key) = services + if let Ok(master_key) = services .users - .get_master_key(sender_user, user_id, &allowed_signatures)? + .get_master_key(sender_user, user_id, &allowed_signatures) + .await { master_keys.insert(user_id.to_owned(), master_key); } - if let Some(self_signing_key) = - services - .users - .get_self_signing_key(sender_user, user_id, &allowed_signatures)? + if let Ok(self_signing_key) = services + .users + .get_self_signing_key(sender_user, user_id, &allowed_signatures) + .await { self_signing_keys.insert(user_id.to_owned(), self_signing_key); } if Some(user_id) == sender_user { - if let Some(user_signing_key) = services.users.get_user_signing_key(user_id)? 
{ + if let Ok(user_signing_key) = services.users.get_user_signing_key(user_id).await { user_signing_keys.insert(user_id.to_owned(), user_signing_key); } } @@ -385,24 +399,27 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( while let Some((server, response)) = futures.next().await { if let Ok(Ok(response)) = response { - for (user, masterkey) in response.master_keys { - let (master_key_id, mut master_key) = services.users.parse_master_key(&user, &masterkey)?; + for (user, master_key) in response.master_keys { + let (master_key_id, mut master_key) = parse_master_key(&user, &master_key)?; - if let Some(our_master_key) = - services - .users - .get_key(&master_key_id, sender_user, &user, &allowed_signatures)? + if let Ok(our_master_key) = services + .users + .get_key(&master_key_id, sender_user, &user, &allowed_signatures) + .await { - let (_, our_master_key) = services.users.parse_master_key(&user, &our_master_key)?; - master_key.signatures.extend(our_master_key.signatures); + let (_, mut our_master_key) = parse_master_key(&user, &our_master_key)?; + master_key.signatures.append(&mut our_master_key.signatures); } let json = serde_json::to_value(master_key).expect("to_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works"); - services.users.add_cross_signing_keys( - &user, &raw, &None, &None, - false, /* Dont notify. A notification would trigger another key request resulting in an - * endless loop */ - )?; + services + .users + .add_cross_signing_keys( + &user, &raw, &None, &None, + false, /* Don't notify. A notification would trigger another key request resulting in an + * endless loop */ + ) + .await?; master_keys.insert(user.clone(), raw); } @@ -449,7 +466,7 @@ fn add_unsigned_device_display_name( } pub(crate) async fn claim_keys_helper( - services: &Services, one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>, + services: &Services, one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, OneTimeKeyAlgorithm>>, ) -> Result<claim_keys::v3::Response> { let mut one_time_keys = BTreeMap::new(); @@ -465,9 +482,10 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { - if let Some(one_time_keys) = services + if let Ok(one_time_keys) = services .users - .take_one_time_key(user_id, device_id, key_algorithm)?
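+ // take_one_time_key is now async and returns a Result; a missing key surfaces as Err instead of None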
+ .take_one_time_key(user_id, device_id, key_algorithm) + .await { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 12012711642ba1c3800164125cabe98a9509bcae..716936184ec26d05e7cc611e8fae787ff5ac0d1c 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -11,6 +11,7 @@ media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN, MXC_LENGTH}, Services, }; +use reqwest::Url; use ruma::{ api::client::{ authenticated_media::{ @@ -165,23 +166,33 @@ pub(crate) async fn get_media_preview_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let url = &body.url; - if !services.media.url_preview_allowed(url) { + let url = Url::parse(&body.url).map_err(|e| { + err!(Request(InvalidParam( + debug_warn!(%sender_user, %url, "Requested URL is not valid: {e}") + ))) + })?; + + if !services.media.url_preview_allowed(&url) { return Err!(Request(Forbidden( debug_warn!(%sender_user, %url, "URL is not allowed to be previewed") ))); } - let preview = services.media.get_url_preview(url).await.map_err(|error| { - err!(Request(Unknown( - debug_error!(%sender_user, %url, ?error, "Failed to fetch URL preview.") - ))) - })?; + let preview = services + .media + .get_url_preview(&url) + .await + .map_err(|error| { + err!(Request(Unknown( + debug_error!(%sender_user, %url, "Failed to fetch URL preview: {error}") + ))) + })?; serde_json::value::to_raw_value(&preview) .map(get_media_preview::v1::Response::from_raw_value) .map_err(|error| { err!(Request(Unknown( - debug_error!(%sender_user, %url, ?error, "Failed to parse URL preview.") + debug_error!(%sender_user, %url, "Failed to parse URL preview: {error}") ))) }) } diff --git a/src/api/client/media_legacy.rs b/src/api/client/media_legacy.rs index e87b9a2b2c9629d943b6fae6b9de6aac8c4739ab..f6837462e9da9177fce98c43c8a1e7c316458ac7 100644 --- a/src/api/client/media_legacy.rs +++ b/src/api/client/media_legacy.rs @@ -8,6 +8,7 @@ Err, Result, }; use conduit_service::media::{Dim, FileMeta, CACHE_CONTROL_IMMUTABLE, CORP_CROSS_ORIGIN}; +use reqwest::Url; use ruma::{ api::client::media::{ create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config, @@ -55,25 +56,31 @@ pub(crate) async fn get_media_preview_legacy_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let url = &body.url; - if !services.media.url_preview_allowed(url) { + let url = Url::parse(&body.url).map_err(|e| { + err!(Request(InvalidParam( + debug_warn!(%sender_user, %url, "Requested URL is not valid: {e}") + ))) + })?; + + if !services.media.url_preview_allowed(&url) { return Err!(Request(Forbidden( debug_warn!(%sender_user, %url, "URL is not allowed to be previewed") ))); } - let preview = services.media.get_url_preview(url).await.map_err(|e| { + let preview = services.media.get_url_preview(&url).await.map_err(|e| { err!(Request(Unknown( debug_error!(%sender_user, %url, "Failed to fetch a URL preview: {e}") ))) })?; - let res = serde_json::value::to_raw_value(&preview).map_err(|e| { - err!(Request(Unknown( - debug_error!(%sender_user, %url, "Failed to parse a URL preview: {e}") - ))) - })?; - - Ok(get_media_preview::v3::Response::from_raw_value(res)) + serde_json::value::to_raw_value(&preview) + .map(get_media_preview::v3::Response::from_raw_value) + .map_err(|error| { + err!(Request(Unknown( + debug_error!(%sender_user, %url, "Failed to parse URL preview: {error}") + ))) + }) } /// # `GET 
/_matrix/media/v1/preview_url` diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 470db669369fc14421a392b7562a649acaf3c5f1..9478e383defa6aebf966acb81bafa087854a0b2b 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1,19 +1,20 @@ use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, net::IpAddr, sync::Arc, - time::Instant, }; use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduit::{ - debug, debug_error, debug_warn, err, error, info, + debug, debug_info, debug_warn, err, error, info, pdu, pdu::{gen_event_id_canonical_json, PduBuilder}, + result::FlatOk, trace, utils, - utils::math::continue_exponential_backoff_secs, + utils::{shuffle, IterStream, ReadyExt}, warn, Err, Error, PduEvent, Result, }; +use futures::{FutureExt, StreamExt}; use ruma::{ api::{ client::{ @@ -33,15 +34,16 @@ member::{MembershipState, RoomMemberEventContent}, message::RoomMessageEventContent, }, - StateEventType, TimelineEventType, + StateEventType, }, - serde::Base64, - state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, - OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, + state_res, CanonicalJsonObject, CanonicalJsonValue, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, + RoomVersionId, ServerName, UserId, +}; +use service::{ + appservice::RegistrationInfo, + rooms::{state::RoomMutexGuard, state_compressor::HashSetCompressStateEvent}, + Services, }; -use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use service::{appservice::RegistrationInfo, rooms::state::RoomMutexGuard, Services}; -use tokio::sync::RwLock; use crate::{client::full_user_deactivate, Ruma}; @@ -55,9 +57,9 @@ async fn banned_room_check( services: &Services, user_id: &UserId, room_id: Option<&RoomId>, server_name: Option<&ServerName>, client_ip: IpAddr, ) -> Result<()> { - if !services.users.is_admin(user_id)? { + if !services.users.is_admin(user_id).await { if let Some(room_id) = room_id { - if services.rooms.metadata.is_banned(room_id)? 
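+ // the room ban lookup is async now and resolves to a plain bool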
+ if services.rooms.metadata.is_banned(room_id).await || services .globals .config @@ -79,23 +81,22 @@ async fn banned_room_check( "Automatically deactivating user {user_id} due to attempted banned room join from IP \ {client_ip}" ))) - .await; + .await + .ok(); } let all_joined_rooms: Vec<OwnedRoomId> = services .rooms .state_cache .rooms_joined(user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(services, user_id, all_joined_rooms).await?; + full_user_deactivate(services, user_id, &all_joined_rooms).await?; } - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This room is banned on this homeserver.", - )); + return Err!(Request(Forbidden("This room is banned on this homeserver."))); } } else if let Some(server_name) = server_name { if services @@ -119,23 +120,22 @@ async fn banned_room_check( "Automatically deactivating user {user_id} due to attempted banned room join from IP \ {client_ip}" ))) - .await; + .await + .ok(); } let all_joined_rooms: Vec<OwnedRoomId> = services .rooms .state_cache .rooms_joined(user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - full_user_deactivate(services, user_id, all_joined_rooms).await?; + full_user_deactivate(services, user_id, &all_joined_rooms).await?; } - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This remote server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("This remote server is banned on this homeserver."))); } } } @@ -168,24 +168,24 @@ pub(crate) async fn join_room_by_id_route( .await?; // There is no body.server_name for /roomId/join - let mut servers = services + let mut servers: Vec<_> = services .rooms .state_cache .servers_invite_via(&body.room_id) - .filter_map(Result::ok) - .collect::<Vec<_>>(); + .map(ToOwned::to_owned) + .collect() + .await; servers.extend( services .rooms .state_cache - .invite_state(sender_user, &body.room_id)? + .invite_state(sender_user, &body.room_id) + .await .unwrap_or_default() .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) .map(|user| user.server_name().to_owned()), ); @@ -193,6 +193,10 @@ pub(crate) async fn join_room_by_id_route( servers.push(server.into()); } + servers.sort_unstable(); + servers.dedup(); + shuffle(&mut servers); + join_room_by_id_helper( &services, sender_user, @@ -202,6 +206,7 @@ pub(crate) async fn join_room_by_id_route( body.third_party_signed.as_ref(), &body.appservice_info, ) + .boxed() .await } @@ -233,20 +238,21 @@ pub(crate) async fn join_room_by_id_or_alias_route( .rooms .state_cache .servers_invite_via(&room_id) - .filter_map(Result::ok), + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await, ); servers.extend( services .rooms .state_cache - .invite_state(sender_user, &room_id)? 
+ .invite_state(sender_user, &room_id) + .await .unwrap_or_default() .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) + .filter_map(|event| event.get_field("sender").ok().flatten()) + .filter_map(|sender: &str| UserId::parse(sender).ok()) .map(|user| user.server_name().to_owned()), ); @@ -254,43 +260,48 @@ pub(crate) async fn join_room_by_id_or_alias_route( servers.push(server.to_owned()); } + servers.sort_unstable(); + servers.dedup(); + shuffle(&mut servers); + (servers, room_id) }, Err(room_alias) => { - let response = services + let (room_id, mut servers) = services .rooms .alias - .resolve_alias(&room_alias, Some(&body.via.clone())) + .resolve_alias(&room_alias, Some(body.via.clone())) .await?; - let (room_id, mut pre_servers) = response; banned_room_check(&services, sender_user, Some(&room_id), Some(room_alias.server_name()), client).await?; - let mut servers = body.via; - if let Some(pre_servers) = &mut pre_servers { - servers.append(pre_servers); - } - servers.extend( - services - .rooms - .state_cache - .servers_invite_via(&room_id) - .filter_map(Result::ok), - ); + let addl_via_servers = services + .rooms + .state_cache + .servers_invite_via(&room_id) + .map(ToOwned::to_owned); - servers.extend( - services - .rooms - .state_cache - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); + let addl_state_servers = services + .rooms + .state_cache + .invite_state(sender_user, &room_id) + .await + .unwrap_or_default(); + + let mut addl_servers: Vec<_> = addl_state_servers + .iter() + .map(|event| event.get_field("sender")) + .filter_map(FlatOk::flat_ok) + .map(|user: &UserId| user.server_name().to_owned()) + .stream() + .chain(addl_via_servers) + .collect() + .await; + + addl_servers.sort_unstable(); + addl_servers.dedup(); + shuffle(&mut addl_servers); + servers.append(&mut addl_servers); (servers, room_id) }, @@ -305,6 +316,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( body.third_party_signed.as_ref(), appservice_info, ) + .boxed() .await?; Ok(join_room_by_id_or_alias::v3::Response { @@ -337,7 +349,7 @@ pub(crate) async fn invite_user_route( ) -> Result<invite_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services.users.is_admin(sender_user)? 
&& services.globals.block_non_admin_invites() { + if !services.users.is_admin(sender_user).await && services.globals.block_non_admin_invites() { info!( "User {sender_user} is not an admin and attempted to send an invite to room {}", &body.room_id @@ -354,7 +366,17 @@ pub(crate) async fn invite_user_route( user_id, } = &body.recipient { - invite_helper(&services, sender_user, user_id, &body.room_id, body.reason.clone(), false).await?; + if services.users.user_is_ignored(sender_user, user_id).await { + return Err!(Request(Forbidden("You cannot invite users you have ignored to rooms."))); + } else if services.users.user_is_ignored(user_id, sender_user).await { + // silently drop the invite to the recipient if they've been ignored by the + // sender, pretend it worked + return Ok(invite_user::v3::Response {}); + } + + invite_helper(&services, sender_user, user_id, &body.room_id, body.reason.clone(), false) + .boxed() + .await?; Ok(invite_user::v3::Response {}) } else { Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) @@ -371,35 +393,25 @@ pub(crate) async fn kick_user_route( let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - let mut event: RoomMemberEventContent = serde_json::from_str( - services - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot kick member that's not in the room.", - ))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; - - event.membership = MembershipState::Leave; - event.reason.clone_from(&body.reason); + let event: RoomMemberEventContent = services + .rooms + .state_accessor + .room_state_get_content(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_err(|_| err!(Request(BadState("Cannot kick member that's not in the room."))))?; services .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + body.user_id.to_string(), + &RoomMemberEventContent { + membership: MembershipState::Leave, + reason: body.reason.clone(), + ..event + }, + ), sender_user, &body.room_id, &state_lock, @@ -421,48 +433,35 @@ pub(crate) async fn ban_user_route( let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + let blurhash = services.users.blurhash(&body.user_id).await.ok(); + let event = services .rooms .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? 
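+ // room_state_get_content fetches and deserializes the current member event in one step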
- .map_or( - Ok(RoomMemberEventContent { + .room_state_get_content(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_or_else( + |_| RoomMemberEventContent { + blurhash: blurhash.clone(), + reason: body.reason.clone(), + ..RoomMemberEventContent::new(MembershipState::Ban) + }, + |event| RoomMemberEventContent { membership: MembershipState::Ban, displayname: None, avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), + blurhash: blurhash.clone(), reason: body.reason.clone(), join_authorized_via_users_server: None, - }), - |event| { - serde_json::from_str(event.content.get()) - .map(|event: RoomMemberEventContent| RoomMemberEventContent { - membership: MembershipState::Ban, - displayname: None, - avatar_url: None, - blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), - reason: body.reason.clone(), - join_authorized_via_users_server: None, - ..event - }) - .map_err(|_| Error::bad_database("Invalid member event in database.")) + ..event }, - )?; + ); services .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(body.user_id.to_string(), &event), sender_user, &body.room_id, &state_lock, @@ -484,33 +483,26 @@ pub(crate) async fn unban_user_route( let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - let mut event: RoomMemberEventContent = serde_json::from_str( - services - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? - .ok_or(Error::BadRequest(ErrorKind::BadState, "Cannot unban a user who is not banned."))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; - - event.membership = MembershipState::Leave; - event.reason.clone_from(&body.reason); - event.join_authorized_via_users_server = None; + let event: RoomMemberEventContent = services + .rooms + .state_accessor + .room_state_get_content(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref()) + .await + .map_err(|_| err!(Request(BadState("Cannot unban a user who is not banned."))))?; services .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + body.user_id.to_string(), + &RoomMemberEventContent { + membership: MembershipState::Leave, + reason: body.reason.clone(), + join_authorized_via_users_server: None, + ..event + }, + ), sender_user, &body.room_id, &state_lock, @@ -539,18 +531,16 @@ pub(crate) async fn forget_room_route( if services .rooms .state_cache - .is_joined(sender_user, &body.room_id)? 
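+ // is_joined is async and infallible now, resolving directly to a bool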
+ .is_joined(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "You must leave the room before forgetting it", - )); + return Err!(Request(Unknown("You must leave the room before forgetting it"))); } services .rooms .state_cache - .forget(&body.room_id, sender_user)?; + .forget(&body.room_id, sender_user); Ok(forget_room::v3::Response::new()) } @@ -568,8 +558,9 @@ pub(crate) async fn joined_rooms_route( .rooms .state_cache .rooms_joined(sender_user) - .filter_map(Result::ok) - .collect(), + .map(ToOwned::to_owned) + .collect() + .await, }) } @@ -587,12 +578,10 @@ pub(crate) async fn get_member_events_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + return Err!(Request(Forbidden("You don't have permission to view this room."))); } Ok(get_member_events::v3::Response { @@ -622,30 +611,28 @@ pub(crate) async fn joined_members_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); + return Err!(Request(Forbidden("You don't have permission to view this room."))); } let joined: BTreeMap<OwnedUserId, RoomMember> = services .rooms .state_cache .room_members(&body.room_id) - .filter_map(|user| { - let user = user.ok()?; - - Some(( + .map(ToOwned::to_owned) + .then(|user| async move { + ( user.clone(), RoomMember { - display_name: services.users.displayname(&user).unwrap_or_default(), - avatar_url: services.users.avatar_url(&user).unwrap_or_default(), + display_name: services.users.displayname(&user).await.ok(), + avatar_url: services.users.avatar_url(&user).await.ok(), }, - )) + ) }) - .collect(); + .collect() + .await; Ok(joined_members::v3::Response { joined, @@ -658,58 +645,77 @@ pub async fn join_room_by_id_helper( ) -> Result<join_room_by_id::v3::Response> { let state_lock = services.rooms.state.mutex.lock(room_id).await; - let user_is_guest = services.users.is_deactivated(sender_user).unwrap_or(false) && appservice_info.is_none(); + let user_is_guest = services + .users + .is_deactivated(sender_user) + .await + .unwrap_or(false) + && appservice_info.is_none(); - if matches!(services.rooms.state_accessor.guest_can_join(room_id), Ok(false)) && user_is_guest { + if user_is_guest && !services.rooms.state_accessor.guest_can_join(room_id).await { return Err!(Request(Forbidden("Guests are not allowed to join this room"))); } - if matches!(services.rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) { + if services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { debug_warn!("{sender_user} is already joined in {room_id}"); return Ok(join_room_by_id::v3::Response { room_id: room_id.into(), }); } - if services + let server_in_room = services .rooms .state_cache - .server_in_room(services.globals.server_name(), room_id)? 
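+ // hoist the local-vs-remote decision into a named bool so both join paths below can be boxed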
- || servers.is_empty() - || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])) - { + .server_in_room(services.globals.server_name(), room_id) + .await; + + let local_join = + server_in_room || servers.is_empty() || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])); + + if local_join { join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) - .await + .boxed() + .await?; } else { // Ask a remote server if we are not participating in this room join_room_by_id_helper_remote(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) - .await + .boxed() + .await?; } + + Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) } #[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_remote")] async fn join_room_by_id_helper_remote( services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, -) -> Result<join_room_by_id::v3::Response> { +) -> Result { info!("Joining {room_id} over federation."); let (make_join_response, remote_server) = make_join_request(services, sender_user, room_id, servers).await?; info!("make_join finished"); - let room_version_id = match make_join_response.room_version { - Some(room_version) - if services - .globals - .supported_room_versions() - .contains(&room_version) => - { - room_version - }, - _ => return Err!(BadServerResponse("Room version is not supported")), + let Some(room_version_id) = make_join_response.room_version else { + return Err!(BadServerResponse("Remote room version is not supported by conduwuit")); }; + if !services + .globals + .supported_room_versions() + .contains(&room_version_id) + { + return Err!(BadServerResponse( + "Remote room version {room_version_id} is not supported by conduwuit" + )); + } + let mut join_event_stub: CanonicalJsonObject = serde_json::from_str(make_join_response.event.get()) .map_err(|e| err!(BadServerResponse("Invalid make_join event json received from server: {e:?}")))?; @@ -738,14 +744,12 @@ async fn join_room_by_id_helper_remote( join_event_stub.insert( "content".to_owned(), to_canonical_value(RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + blurhash: services.users.blurhash(sender_user).await.ok(), reason, join_authorized_via_users_server: join_authorized_via_users_server.clone(), + ..RoomMemberEventContent::new(MembershipState::Join) }) .expect("event is valid, we just created it"), ); @@ -761,42 +765,33 @@ async fn join_room_by_id_helper_remote( // In order to create a compatible ref hash (EventID) the `hashes` field needs // to be present - ruma::signatures::hash_and_sign_event( - services.globals.server_name().as_str(), - services.globals.keypair(), - &mut join_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); + services + .server_keys + .hash_and_sign_event(&mut join_event_stub, &room_version_id)?; // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference 
hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); + let event_id = pdu::gen_event_id(&join_event_stub, &room_version_id)?; // Add event_id back - join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); // It has enough fields to be called a proper event now let mut join_event = join_event_stub; info!("Asking {remote_server} for send_join in room {room_id}"); + let send_join_request = federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.clone(), + omit_members: false, + pdu: services + .sending + .convert_to_outgoing_federation_event(join_event.clone()) + .await, + }; + let send_join_response = services .sending - .send_federation_request( - &remote_server, - federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.to_owned(), - pdu: services - .sending - .convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, - }, - ) + .send_synapse_request(&remote_server, send_join_request) .await?; info!("send_join finished"); @@ -814,7 +809,7 @@ async fn join_room_by_id_helper_remote( // validate and send signatures _ => { if let Some(signed_raw) = &send_join_response.room_state.event { - info!( + debug_info!( "There is a signed event. This room is probably using restricted joins. Adding signature to \ our event" ); @@ -864,116 +859,145 @@ async fn join_room_by_id_helper_remote( } } - services.rooms.short.get_or_create_shortroomid(room_id)?; + services + .rooms + .short + .get_or_create_shortroomid(room_id) + .await; info!("Parsing join event"); - let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) + let parsed_join_pdu = PduEvent::from_id_val(&event_id, join_event.clone()) .map_err(|e| err!(BadServerResponse("Invalid join event PDU: {e:?}")))?; - let mut state = HashMap::new(); - let pub_key_map = RwLock::new(BTreeMap::new()); - - info!("Fetching join signing keys"); + info!("Acquiring server signing keys for response events"); + let resp_events = &send_join_response.room_state; + let resp_state = &resp_events.state; + let resp_auth = &resp_events.auth_chain; services .server_keys - .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) - .await?; + .acquire_events_pubkeys(resp_auth.iter().chain(resp_state.iter())) + .await; info!("Going through send_join response room_state"); - for result in send_join_response + let cork = services.db.cork_and_flush(); + let state = send_join_response .room_state .state .iter() - .map(|pdu| validate_and_add_event_id(services, pdu, &room_version_id, &pub_key_map)) - { - let Ok((event_id, value)) = result.await else { - continue; - }; + .stream() + .then(|pdu| { + services + .server_keys + .validate_and_add_event_id_no_fetch(pdu, &room_version_id) + }) + .ready_filter_map(Result::ok) + .fold(HashMap::new(), |mut state, (event_id, value)| async move { + let pdu = match PduEvent::from_id_val(&event_id, value.clone()) { + Ok(pdu) => pdu, + Err(e) => { + debug_warn!("Invalid PDU in send_join response: {e:?}: {value:#?}"); + return state; + }, + }; - let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { - debug_warn!("Invalid PDU in send_join response: {value:#?}"); - err!(BadServerResponse("Invalid PDU in send_join response: {e:?}")) - })?; + 
services.rooms.outlier.add_pdu_outlier(&event_id, &value); + if let Some(state_key) = &pdu.state_key { + let shortstatekey = services + .rooms + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key) + .await; - services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; - if let Some(state_key) = &pdu.state_key { - let shortstatekey = services - .rooms - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; - state.insert(shortstatekey, pdu.event_id.clone()); - } - } + state.insert(shortstatekey, pdu.event_id.clone()); + } + + state + }) + .await; + + drop(cork); info!("Going through send_join response auth_chain"); - for result in send_join_response + let cork = services.db.cork_and_flush(); + send_join_response .room_state .auth_chain .iter() - .map(|pdu| validate_and_add_event_id(services, pdu, &room_version_id, &pub_key_map)) - { - let Ok((event_id, value)) = result.await else { - continue; - }; + .stream() + .then(|pdu| { + services + .server_keys + .validate_and_add_event_id_no_fetch(pdu, &room_version_id) + }) + .ready_filter_map(Result::ok) + .ready_for_each(|(event_id, value)| services.rooms.outlier.add_pdu_outlier(&event_id, &value)) + .await; - services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; - } + drop(cork); debug!("Running send_join auth check"); + let fetch_state = &state; + let state_fetch = |k: &'static StateEventType, s: String| async move { + let shortstatekey = services.rooms.short.get_shortstatekey(k, &s).await.ok()?; + + let event_id = fetch_state.get(&shortstatekey)?; + services.rooms.timeline.get_pdu(event_id).await.ok() + }; let auth_check = state_res::event_auth::auth_check( &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), &parsed_join_pdu, - None::<PduEvent>, // TODO: third party invite - |k, s| { - services - .rooms - .timeline - .get_pdu( - state.get( - &services - .rooms - .short - .get_or_create_shortstatekey(&k.to_string().into(), s) - .ok()?, - )?, - ) - .ok()? 
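+ // the old inline state lookup is factored into the async state_fetch closure above, which resolves through the shortstatekey table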
- }, + None, // TODO: third party invite + |k, s| state_fetch(k, s.to_owned()), ) - .map_err(|e| { - warn!("Auth check failed: {e}"); - Error::BadRequest(ErrorKind::forbidden(), "Auth check failed") - })?; + .await + .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; if !auth_check { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")); + return Err!(Request(Forbidden("Auth check failed"))); } - info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services.rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| services.rooms.state_compressor.compress_state_event(k, &id)) - .collect::<Result<_>>()?, - ), - )?; + info!("Compressing state from send_join"); + let compressed = state + .iter() + .stream() + .then(|(&k, id)| services.rooms.state_compressor.compress_state_event(k, id)) + .collect() + .await; + + debug!("Saving compressed state"); + let HashSetCompressStateEvent { + shortstatehash: statehash_before_join, + added, + removed, + } = services + .rooms + .state_compressor + .save_state(room_id, Arc::new(compressed)) + .await?; + debug!("Forcing state for new room"); services .rooms .state - .force_state(room_id, statehash_before_join, new, removed, &state_lock) + .force_state(room_id, statehash_before_join, added, removed, &state_lock) .await?; info!("Updating joined counts for new room"); - services.rooms.state_cache.update_joined_count(room_id)?; + services + .rooms + .state_cache + .update_joined_count(room_id) + .await; // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without its state. This is okay because append_pdu can't // fail. - let statehash_after_join = services.rooms.state.append_to_state(&parsed_join_pdu)?; + let statehash_after_join = services + .rooms + .state + .append_to_state(&parsed_join_pdu) + .await?; info!("Appending new room join event"); services @@ -993,35 +1017,27 @@ async fn join_room_by_id_helper_remote( services .rooms .state - .set_room_state(room_id, statehash_after_join, &state_lock)?; + .set_room_state(room_id, statehash_after_join, &state_lock); - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) + Ok(()) } #[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_local")] async fn join_room_by_id_helper_local( services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, -) -> Result<join_room_by_id::v3::Response> { +) -> Result { debug!("We can join locally"); - let join_rules_event = services + let join_rules_event_content = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; - - let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; + .room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") + .await + .map(|content: RoomJoinRulesEventContent| content); let restriction_rooms = match join_rules_event_content { - Some(RoomJoinRulesEventContent { + Ok(RoomJoinRulesEventContent { join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted), }) => restricted .allow @@ -1034,29 +1050,34 @@
async fn join_room_by_id_helper_local( _ => Vec::new(), }; - let local_members = services + let local_members: Vec<_> = services .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .filter(|user| services.globals.user_is_local(user)) - .collect::<Vec<OwnedUserId>>(); + .ready_filter(|user| services.globals.user_is_local(user)) + .map(ToOwned::to_owned) + .collect() + .await; let mut join_authorized_via_users_server: Option<OwnedUserId> = None; - if restriction_rooms.iter().any(|restriction_room_id| { - services - .rooms - .state_cache - .is_joined(sender_user, restriction_room_id) - .unwrap_or(false) - }) { + if restriction_rooms + .iter() + .stream() + .any(|restriction_room_id| { + services + .rooms + .state_cache + .is_joined(sender_user, restriction_room_id) + }) + .await + { for user in local_members { if services .rooms .state_accessor .user_can_invite(room_id, &user, sender_user, &state_lock) - .unwrap_or(false) + .await { join_authorized_via_users_server = Some(user); break; @@ -1064,15 +1085,13 @@ async fn join_room_by_id_helper_local( } } - let event = RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + let content = RoomMemberEventContent { + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + blurhash: services.users.blurhash(sender_user).await.ok(), reason: reason.clone(), join_authorized_via_users_server, + ..RoomMemberEventContent::new(MembershipState::Join) }; // Try normal join first @@ -1080,21 +1099,14 @@ async fn join_room_by_id_helper_local( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(sender_user.to_string(), &content), sender_user, room_id, &state_lock, ) .await { - Ok(_event_id) => return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())), + Ok(_) => return Ok(()), Err(e) => e, }; @@ -1106,17 +1118,20 @@ async fn join_room_by_id_helper_local( warn!("We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements"); let (make_join_response, remote_server) = make_join_request(services, sender_user, room_id, servers).await?; - let room_version_id = match make_join_response.room_version { - Some(room_version_id) - if services - .globals - .supported_room_versions() - .contains(&room_version_id) => - { - room_version_id - }, - _ => return Err!(BadServerResponse("Room version is not supported")), + let Some(room_version_id) = make_join_response.room_version else { + return Err!(BadServerResponse("Remote room version is not supported by conduwuit")); }; + + if !services + .globals + .supported_room_versions() + .contains(&room_version_id) + { + return Err!(BadServerResponse( + "Remote room version {room_version_id} is not supported by conduwuit" + )); + } + let mut join_event_stub: CanonicalJsonObject = serde_json::from_str(make_join_response.event.get()) .map_err(|e| err!(BadServerResponse("Invalid make_join event json received from server: {e:?}")))?; let join_authorized_via_users_server = join_event_stub @@ -1143,14 +1158,12 @@ async fn 
join_room_by_id_helper_local( join_event_stub.insert( "content".to_owned(), to_canonical_value(RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + blurhash: services.users.blurhash(sender_user).await.ok(), reason, join_authorized_via_users_server, + ..RoomMemberEventContent::new(MembershipState::Join) }) .expect("event is valid, we just created it"), ); @@ -1166,39 +1179,31 @@ async fn join_room_by_id_helper_local( // In order to create a compatible ref hash (EventID) the `hashes` field needs // to be present - ruma::signatures::hash_and_sign_event( - services.globals.server_name().as_str(), - services.globals.keypair(), - &mut join_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); + services + .server_keys + .hash_and_sign_event(&mut join_event_stub, &room_version_id)?; // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); + let event_id = pdu::gen_event_id(&join_event_stub, &room_version_id)?; // Add event_id back - join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into())); // It has enough fields to be called a proper event now let join_event = join_event_stub; let send_join_response = services .sending - .send_federation_request( + .send_synapse_request( &remote_server, federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), - event_id: event_id.to_owned(), + event_id: event_id.clone(), + omit_members: false, pdu: services .sending - .convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, + .convert_to_outgoing_federation_event(join_event.clone()) + .await, }, ) .await?; @@ -1220,15 +1225,10 @@ async fn join_room_by_id_helper_local( } drop(state_lock); - let pub_key_map = RwLock::new(BTreeMap::new()); - services - .server_keys - .fetch_required_signing_keys([&signed_value], &pub_key_map) - .await?; services .rooms .event_handler - .handle_incoming_pdu(&remote_server, room_id, &signed_event_id, signed_value, true, &pub_key_map) + .handle_incoming_pdu(&remote_server, room_id, &signed_event_id, signed_value, true) .await?; } else { return Err(error); @@ -1237,7 +1237,7 @@ async fn join_room_by_id_helper_local( return Err(error); } - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) + Ok(()) } async fn make_join_request( @@ -1245,8 +1245,8 @@ async fn make_join_request( ) -> Result<(federation::membership::prepare_join_event::v1::Response, OwnedServerName)> { let mut make_join_response_and_server = Err!(BadServerResponse("No server available to assist in joining.")); - let mut make_join_counter: u16 = 0; - let mut incompatible_room_version_count: u8 = 0; + let mut make_join_counter: usize = 0; + let mut incompatible_room_version_count: usize = 0; for remote_server in servers { if services.globals.server_is_ours(remote_server) { @@ -1269,28 +1269,25 @@ async fn 
make_join_request( make_join_counter = make_join_counter.saturating_add(1); if let Err(ref e) = make_join_response { - trace!("make_join ErrorKind string: {:?}", e.kind().to_string()); - - // converting to a string is necessary (i think) because ruma is forcing us to - // fill in the struct for M_INCOMPATIBLE_ROOM_VERSION - if e.kind().to_string().contains("M_INCOMPATIBLE_ROOM_VERSION") - || e.kind().to_string().contains("M_UNSUPPORTED_ROOM_VERSION") - { + if matches!( + e.kind(), + ErrorKind::IncompatibleRoomVersion { .. } | ErrorKind::UnsupportedRoomVersion + ) { incompatible_room_version_count = incompatible_room_version_count.saturating_add(1); } if incompatible_room_version_count > 15 { info!( "15 servers have responded with M_INCOMPATIBLE_ROOM_VERSION or M_UNSUPPORTED_ROOM_VERSION, \ - assuming that Conduwuit does not support the room {room_id}: {e}" + assuming that conduwuit does not support the room version of {room_id}: {e}" ); make_join_response_and_server = Err!(BadServerResponse("Room version is not supported by Conduwuit")); return make_join_response_and_server; } - if make_join_counter > 50 { + if make_join_counter > 40 { warn!( - "50 servers failed to provide valid make_join response, assuming no server can assist in joining." + "40 servers failed to provide a valid make_join response, assuming no server can assist in joining." ); make_join_response_and_server = Err!(BadServerResponse("No server available to assist in joining.")); return make_join_response_and_server; @@ -1307,69 +1304,11 @@ async fn make_join_request( make_join_response_and_server } -pub async fn validate_and_add_event_id( - services: &Services, pdu: &RawJsonValue, room_version: &RoomVersionId, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, -) -> Result<(OwnedEventId, CanonicalJsonObject)> { - let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - debug_error!("Invalid PDU in server response: {pdu:#?}"); - err!(BadServerResponse("Invalid PDU in server response: {e:?}")) - })?; - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&value, room_version).expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - let back_off = |id| async { - match services - .globals - .bad_event_ratelimiter - .write() - .expect("locked") - .entry(id) - { - Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)); - }, - } - }; - - if let Some((time, tries)) = services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(&event_id) - { - // Exponential backoff - const MIN: u64 = 60 * 5; - const MAX: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN, MAX, time.elapsed(), *tries) { - return Err!(BadServerResponse("bad event {event_id:?}, still backing off")); - } - } - - if let Err(e) = ruma::signatures::verify_event(&*pub_key_map.read().await, &value, room_version) { - debug_error!("Event {event_id} failed verification {pdu:#?}"); - let e = Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}"))); - back_off(event_id).await; - return e; - } - - value.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - - Ok((event_id, value)) -} pub(crate) async fn invite_helper( services: &Services, sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option<String>, is_direct: bool, ) -> Result<()> { - if
!services.users.is_admin(user_id)? && services.globals.block_non_admin_invites() { + if !services.users.is_admin(sender_user).await && services.globals.block_non_admin_invites() { info!("User {sender_user} is not an admin and attempted to send an invite to room {room_id}"); return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -1380,40 +1319,33 @@ pub(crate) async fn invite_helper( if !services.globals.user_is_local(user_id) { let (pdu, pdu_json, invite_room_state) = { let state_lock = services.rooms.state.mutex.lock(room_id).await; - let content = to_raw_value(&RoomMemberEventContent { - avatar_url: services.users.avatar_url(user_id)?, - displayname: None, + + let content = RoomMemberEventContent { + avatar_url: services.users.avatar_url(user_id).await.ok(), is_direct: Some(is_direct), - membership: MembershipState::Invite, - third_party_invite: None, - blurhash: None, reason, - join_authorized_via_users_server: None, - }) - .expect("member event is valid value"); - - let (pdu, pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - sender_user, - room_id, - &state_lock, - )?; + ..RoomMemberEventContent::new(MembershipState::Invite) + }; + + let (pdu, pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder::state(user_id.to_string(), &content), + sender_user, + room_id, + &state_lock, + ) + .await?; - let invite_room_state = services.rooms.state.calculate_invite_state(&pdu)?; + let invite_room_state = services.rooms.state.summary_stripped(&pdu).await; drop(state_lock); (pdu, pdu_json, invite_room_state) }; - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let response = services .sending @@ -1425,15 +1357,19 @@ pub(crate) async fn invite_helper( room_version: room_version_id.clone(), event: services .sending - .convert_to_outgoing_federation_event(pdu_json.clone()), + .convert_to_outgoing_federation_event(pdu_json.clone()) + .await, invite_room_state, - via: services.rooms.state_cache.servers_route_via(room_id).ok(), + via: services + .rooms + .state_cache + .servers_route_via(room_id) + .await + .ok(), }, ) .await?; - let pub_key_map = RwLock::new(BTreeMap::new()); - // We do not add the event_id field to the pdu here because of signature and // hashes checks let Ok((event_id, value)) = gen_event_id_canonical_json(&response.event, &room_version_id) else { @@ -1446,10 +1382,8 @@ pub(crate) async fn invite_helper( if *pdu.event_id != *event_id { warn!( - "Server {} changed invite event, that's not allowed in the spec: ours: {:?}, theirs: {:?}", + "Server {} changed invite event, that's not allowed in the spec: ours: {pdu_json:?}, theirs: {value:?}", user_id.server_name(), - pdu_json, - value ); } @@ -1463,26 +1397,23 @@ pub(crate) async fn invite_helper( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - services - .server_keys - .fetch_required_signing_keys([&value], &pub_key_map) - .await?; - - let pdu_id: Vec<u8> = services + let pdu_id = services .rooms .event_handler - .handle_incoming_pdu(&origin, room_id, &event_id, value, true, &pub_key_map) + .handle_incoming_pdu(&origin, room_id, &event_id, value, true) .await? 
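+ // handle_incoming_pdu yields an Option; None means the event was not accepted as a timeline event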
- .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; + .ok_or_else(|| err!(Request(InvalidParam("Could not accept incoming PDU as timeline event."))))?; - services.sending.send_pdu_room(room_id, &pdu_id)?; + services.sending.send_pdu_room(room_id, &pdu_id).await?; return Ok(()); } - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { + if !services + .rooms + .state_cache + .is_joined(sender_user, room_id) + .await + { return Err(Error::BadRequest( ErrorKind::forbidden(), "You don't have permission to view this room.", @@ -1491,28 +1422,20 @@ pub(crate) async fn invite_helper( let state_lock = services.rooms.state.mutex.lock(room_id).await; + let content = RoomMemberEventContent { + displayname: services.users.displayname(user_id).await.ok(), + avatar_url: services.users.avatar_url(user_id).await.ok(), + blurhash: services.users.blurhash(user_id).await.ok(), + is_direct: Some(is_direct), + reason, + ..RoomMemberEventContent::new(MembershipState::Invite) + }; + services .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: services.users.displayname(user_id)?, - avatar_url: services.users.avatar_url(user_id)?, - is_direct: Some(is_direct), - third_party_invite: None, - blurhash: services.users.blurhash(user_id)?, - reason, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(user_id.to_string(), &content), sender_user, room_id, &state_lock, @@ -1527,77 +1450,84 @@ pub(crate) async fn invite_helper( // Make a user leave all their joined rooms, forgets all rooms, and ignores // errors pub async fn leave_all_rooms(services: &Services, user_id: &UserId) { - let all_rooms = services + let rooms_joined = services .rooms .state_cache .rooms_joined(user_id) - .chain( - services - .rooms - .state_cache - .rooms_invited(user_id) - .map(|t| t.map(|(r, _)| r)), - ) - .collect::<Vec<_>>(); + .map(ToOwned::to_owned); - for room_id in all_rooms { - let Ok(room_id) = room_id else { - continue; - }; + let rooms_invited = services + .rooms + .state_cache + .rooms_invited(user_id) + .map(|(r, _)| r); + + let all_rooms: Vec<_> = rooms_joined.chain(rooms_invited).collect().await; + for room_id in all_rooms { // ignore errors if let Err(e) = leave_room(services, user_id, &room_id, None).await { warn!(%room_id, %user_id, %e, "Failed to leave room"); } - if let Err(e) = services.rooms.state_cache.forget(&room_id, user_id) { - warn!(%room_id, %user_id, %e, "Failed to forget room"); - } + + services.rooms.state_cache.forget(&room_id, user_id); } } pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, reason: Option<String>) -> Result<()> { + //use conduit::utils::stream::OptionStream; + use futures::TryFutureExt; + // Ask a remote server if we don't have this room if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), room_id)? 
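+ // server_in_room resolves asynchronously to a bool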
+ .server_in_room(services.globals.server_name(), room_id) + .await { if let Err(e) = remote_leave_room(services, user_id, room_id).await { - warn!("Failed to leave room {} remotely: {}", user_id, e); + warn!(%user_id, "Failed to leave room {room_id} remotely: {e}"); // Don't tell the client about this error } let last_state = services .rooms .state_cache - .invite_state(user_id, room_id)? - .map_or_else(|| services.rooms.state_cache.left_state(user_id, room_id), |s| Ok(Some(s)))?; + .invite_state(user_id, room_id) + .map_err(|_| services.rooms.state_cache.left_state(user_id, room_id)) + .await + .ok(); // We always drop the invite, we can't rely on other servers - services.rooms.state_cache.update_membership( - room_id, - user_id, - RoomMemberEventContent::new(MembershipState::Leave), - user_id, - last_state, - None, - true, - )?; + services + .rooms + .state_cache + .update_membership( + room_id, + user_id, + RoomMemberEventContent::new(MembershipState::Leave), + user_id, + last_state, + None, + true, + ) + .await?; } else { let state_lock = services.rooms.state.mutex.lock(room_id).await; - let member_event = + let Ok(event) = services + .rooms + .state_accessor + .room_state_get_content::<RoomMemberEventContent>(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await + else { + // Fix for broken rooms + error!("Trying to leave a room you are not a member of."); + services .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?; - - // Fix for broken rooms - let member_event = match member_event { - None => { - error!("Trying to leave a room you are not a member of."); - - services.rooms.state_cache.update_membership( + .state_cache + .update_membership( room_id, user_id, RoomMemberEventContent::new(MembershipState::Leave), @@ -1605,32 +1535,24 @@ pub async fn leave_room(services: &Services, user_id: &UserId, room_id: &RoomId, None, None, true, - )?; - return Ok(()); - }, - Some(e) => e, - }; - - let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()).map_err(|e| { - error!("Invalid room member event in database: {}", e); - Error::bad_database("Invalid member event in database.") - })?; + ) + .await?; - event.membership = MembershipState::Leave; - event.reason = reason; + return Ok(()); + }; services .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + user_id.to_string(), + &RoomMemberEventContent { + membership: MembershipState::Leave, + reason, + ..event + }, + ), user_id, room_id, &state_lock, @@ -1647,23 +1569,23 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room let invite_state = services .rooms .state_cache - .invite_state(user_id, room_id)? 
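+ // invite_state returns a Result now, so a missing invite maps directly to the BadState error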
-		.ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?;
+		.invite_state(user_id, room_id)
+		.await
+		.map_err(|_| err!(Request(BadState("User is not invited."))))?;

 	let mut servers: HashSet<OwnedServerName> = services
 		.rooms
 		.state_cache
 		.servers_invite_via(room_id)
-		.filter_map(Result::ok)
-		.collect();
+		.map(ToOwned::to_owned)
+		.collect()
+		.await;

 	servers.extend(
 		invite_state
 			.iter()
-			.filter_map(|event| serde_json::from_str(event.json().get()).ok())
-			.filter_map(|event: serde_json::Value| event.get("sender").cloned())
-			.filter_map(|sender| sender.as_str().map(ToOwned::to_owned))
-			.filter_map(|sender| UserId::parse(sender).ok())
+			.filter_map(|event| event.get_field("sender").ok().flatten())
+			.filter_map(|sender: &str| UserId::parse(sender).ok())
 			.map(|user| user.server_name().to_owned()),
 	);

@@ -1690,18 +1612,20 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room

 	let (make_leave_response, remote_server) = make_leave_response_and_server?;

-	let room_version_id = match make_leave_response.room_version {
-		Some(version)
-			if services
-				.globals
-				.supported_room_versions()
-				.contains(&version) =>
-		{
-			version
-		},
-		_ => return Err!(BadServerResponse("Room version is not supported")),
+	let Some(room_version_id) = make_leave_response.room_version else {
+		return Err!(BadServerResponse("Remote room version is not supported by conduwuit"));
 	};

+	if !services
+		.globals
+		.supported_room_versions()
+		.contains(&room_version_id)
+	{
+		return Err!(BadServerResponse(
+			"Remote room version {room_version_id} is not supported by conduwuit"
+		));
+	}
+
 	let mut leave_event_stub = serde_json::from_str::<CanonicalJsonObject>(make_leave_response.event.get())
 		.map_err(|e| err!(BadServerResponse("Invalid make_leave event json received from server: {e:?}")))?;

@@ -1729,24 +1653,15 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room

 	// In order to create a compatible ref hash (EventID) the `hashes` field needs
 	// to be present
-	ruma::signatures::hash_and_sign_event(
-		services.globals.server_name().as_str(),
-		services.globals.keypair(),
-		&mut leave_event_stub,
-		&room_version_id,
-	)
-	.expect("event is valid, we just created it");
+	services
+		.server_keys
+		.hash_and_sign_event(&mut leave_event_stub, &room_version_id)?;

 	// Generate event id
-	let event_id = EventId::parse(format!(
-		"${}",
-		ruma::signatures::reference_hash(&leave_event_stub, &room_version_id)
-			.expect("ruma can calculate reference hashes")
-	))
-	.expect("ruma's reference hashes are valid event ids");
+	let event_id = pdu::gen_event_id(&leave_event_stub, &room_version_id)?;

 	// Add event_id back
-	leave_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
+	leave_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.clone().into()));

 	// It has enough fields to be called a proper event now
 	let leave_event = leave_event_stub;

@@ -1760,7 +1675,8 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room
 			event_id,
 			pdu: services
 				.sending
-				.convert_to_outgoing_federation_event(leave_event.clone()),
+				.convert_to_outgoing_federation_event(leave_event.clone())
+				.await,
 		},
 	)
 	.await?;
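// `pdu::gen_event_id`, called above, replaces the inline event-ID computation
// that this hunk removes. Reconstructed from those removed lines, the helper
// presumably does something close to the following (exact signature and error
// conversion are assumptions):

use ruma::{CanonicalJsonObject, OwnedEventId, RoomVersionId};

pub fn gen_event_id(value: &CanonicalJsonObject, room_version_id: &RoomVersionId) -> Result<OwnedEventId> {
	// The reference hash of the event, computed per the room version's
	// rules, becomes the event ID: `$<hash>`.
	let reference_hash = ruma::signatures::reference_hash(value, room_version_id)?;
	let event_id = format!("${reference_hash}").try_into()?;

	Ok(event_id)
}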
diff --git a/src/api/client/message.rs b/src/api/client/message.rs
index 51aee8c12f04391652b6c22c3ff4d5e3c541e907..88453de0cb70ee7537f5925d4316ff9fec3f349d 100644
--- a/src/api/client/message.rs
+++ b/src/api/client/message.rs
@@ -1,109 +1,52 @@
-use std::collections::{BTreeMap, HashSet};
+use std::collections::HashSet;

 use axum::extract::State;
-use conduit::PduCount;
-use ruma::{
-	api::client::{
-		error::ErrorKind,
-		filter::{RoomEventFilter, UrlFilter},
-		message::{get_message_events, send_message_event},
+use conduit::{
+	at, is_equal_to,
+	utils::{
+		result::{FlatOk, LogErr},
+		IterStream, ReadyExt,
 	},
-	events::{MessageLikeEventType, StateEventType},
-	RoomId, UserId,
+	Event, PduCount, Result,
 };
-use serde_json::{from_str, Value};
-
-use crate::{
-	service::{pdu::PduBuilder, Services},
-	utils, Error, PduEvent, Result, Ruma,
+use futures::{FutureExt, StreamExt};
+use ruma::{
+	api::{
+		client::{filter::RoomEventFilter, message::get_message_events},
+		Direction,
+	},
+	events::{AnyStateEvent, StateEventType, TimelineEventType, TimelineEventType::*},
+	serde::Raw,
+	DeviceId, OwnedUserId, RoomId, UserId,
 };
-
-/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
-///
-/// Send a message event into the room.
-///
-/// - Is a NOOP if the txn id was already used before and returns the same event
-///   id again
-/// - The only requirement for the content is that it has to be valid json
-/// - Tries to send the event into the room, auth rules will determine if it is
-///   allowed
-pub(crate) async fn send_message_event_route(
-	State(services): State<crate::State>, body: Ruma<send_message_event::v3::Request>,
-) -> Result<send_message_event::v3::Response> {
-	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
-	let sender_device = body.sender_device.as_deref();
-
-	let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
-
-	// Forbid m.room.encrypted if encryption is disabled
-	if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() {
-		return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled"));
-	}
-
-	if body.event_type == MessageLikeEventType::CallInvite && services.rooms.directory.is_public_room(&body.room_id)? {
-		return Err(Error::BadRequest(
-			ErrorKind::forbidden(),
-			"Room call invites are not allowed in public rooms",
-		));
-	}
-
-	// Check if this is a new transaction id
-	if let Some(response) = services
-		.transaction_ids
-		.existing_txnid(sender_user, sender_device, &body.txn_id)?
-	{
-		// The client might have sent a txnid of the /sendToDevice endpoint
-		// This txnid has no response associated with it
-		if response.is_empty() {
-			return Err(Error::BadRequest(
-				ErrorKind::InvalidParam,
-				"Tried to use txn id already used for an incompatible endpoint.",
-			));
-		}
-
-		let event_id = utils::string_from_bytes(&response)
-			.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
-			.try_into()
-			.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
-		return Ok(send_message_event::v3::Response {
-			event_id,
-		});
-	}
-
-	let mut unsigned = BTreeMap::new();
-	unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
-
-	let event_id = services
-		.rooms
-		.timeline
-		.build_and_append_pdu(
-			PduBuilder {
-				event_type: body.event_type.to_string().into(),
-				content: from_str(body.body.body.json().get())
-					.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?,
-				unsigned: Some(unsigned),
-				state_key: None,
-				redacts: None,
-				timestamp: if body.appservice_info.is_some() {
-					body.timestamp
-				} else {
-					None
-				},
-			},
-			sender_user,
-			&body.room_id,
-			&state_lock,
-		)
-		.await?;
-
-	services
-		.transaction_ids
-		.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;
-
-	drop(state_lock);
-
-	Ok(send_message_event::v3::Response::new((*event_id).to_owned()))
-}
+use service::{rooms::timeline::PdusIterItem, Services};
+
+use crate::Ruma;
+
+pub(crate) type LazySet = HashSet<OwnedUserId>;
+
+/// list of safe and common non-state events to ignore
+const IGNORED_MESSAGE_TYPES: &[TimelineEventType] = &[
+	RoomMessage,
+	Sticker,
+	CallInvite,
+	CallNotify,
+	RoomEncrypted,
+	Image,
+	File,
+	Audio,
+	Voice,
+	Video,
+	UnstablePollStart,
+	PollStart,
+	KeyVerificationStart,
+	Reaction,
+	Emote,
+	Location,
+];
+
+const LIMIT_MAX: usize = 100;
+const LIMIT_DEFAULT: usize = 10;

 /// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
 ///
@@ -114,169 +57,186 @@ pub(crate) async fn send_message_event_route(
 pub(crate) async fn get_message_events_route(
 	State(services): State<crate::State>, body: Ruma<get_message_events::v3::Request>,
 ) -> Result<get_message_events::v3::Response> {
-	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
-	let sender_device = body.sender_device.as_ref().expect("user is authenticated");
-
-	let from = match body.from.clone() {
-		Some(from) => PduCount::try_from_string(&from)?,
-		None => match body.dir {
-			ruma::api::Direction::Forward => PduCount::min(),
-			ruma::api::Direction::Backward => PduCount::max(),
-		},
-	};
+	let sender = body.sender();
+	let (sender_user, sender_device) = sender;
+	let room_id = &body.room_id;
+	let filter = &body.filter;
+
+	let from: PduCount = body
+		.from
+		.as_deref()
+		.map(str::parse)
+		.transpose()?
+		.unwrap_or_else(|| match body.dir {
+			Direction::Forward => PduCount::min(),
+			Direction::Backward => PduCount::max(),
+		});
+
+	let to: Option<PduCount> = body.to.as_deref().map(str::parse).flat_ok();

-	let to = body
-		.to
-		.as_ref()
-		.and_then(|t| PduCount::try_from_string(t).ok());
+	let limit: usize = body
+		.limit
+		.try_into()
+		.unwrap_or(LIMIT_DEFAULT)
+		.min(LIMIT_MAX);

 	services
 		.rooms
 		.lazy_loading
-		.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)
-		.await?;
-
-	let limit = usize::try_from(body.limit).unwrap_or(10).min(100);
-
-	let next_token;
-
-	let mut resp = get_message_events::v3::Response::new();
-
-	let mut lazy_loaded = HashSet::new();
-
-	match body.dir {
-		ruma::api::Direction::Forward => {
-			let events_after: Vec<_> = services
-				.rooms
-				.timeline
-				.pdus_after(sender_user, &body.room_id, from)?
-				.filter_map(Result::ok) // Filter out buggy events
-				.filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id)
-
-				})
-				.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
-				.take(limit)
-				.collect();
-
-			for (_, event) in &events_after {
-				/* TODO: Remove the not "element_hacks" check when these are resolved:
-				 * https://github.com/vector-im/element-android/issues/3417
-				 * https://github.com/vector-im/element-web/issues/21034
-				 */
-				if !cfg!(feature = "element_hacks")
-					&& !services.rooms.lazy_loading.lazy_load_was_sent_before(
-						sender_user,
-						sender_device,
-						&body.room_id,
-						&event.sender,
-					)? {
-					lazy_loaded.insert(event.sender.clone());
-				}
-
-				lazy_loaded.insert(event.sender.clone());
-			}
-
-			next_token = events_after.last().map(|(count, _)| count).copied();
-
-			let events_after: Vec<_> = events_after
-				.into_iter()
-				.map(|(_, pdu)| pdu.to_room_event())
-				.collect();
-
-			resp.start = from.stringify();
-			resp.end = next_token.map(|count| count.stringify());
-			resp.chunk = events_after;
-		},
-		ruma::api::Direction::Backward => {
-			services
-				.rooms
-				.timeline
-				.backfill_if_required(&body.room_id, from)
-				.await?;
-			let events_before: Vec<_> = services
-				.rooms
-				.timeline
-				.pdus_until(sender_user, &body.room_id, from)?
-				.filter_map(Result::ok) // Filter out buggy events
-				.filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id)})
-				.take_while(|&(k, _)| Some(k) != to) // Stop at `to`
-				.take(limit)
-				.collect();
-
-			for (_, event) in &events_before {
-				/* TODO: Remove the not "element_hacks" check when these are resolved:
-				 * https://github.com/vector-im/element-android/issues/3417
-				 * https://github.com/vector-im/element-web/issues/21034
-				 */
-				if !cfg!(feature = "element_hacks")
-					&& !services.rooms.lazy_loading.lazy_load_was_sent_before(
-						sender_user,
-						sender_device,
-						&body.room_id,
-						&event.sender,
-					)? {
-					lazy_loaded.insert(event.sender.clone());
-				}
-
-				lazy_loaded.insert(event.sender.clone());
-			}
-
-			next_token = events_before.last().map(|(count, _)| count).copied();
-
-			let events_before: Vec<_> = events_before
-				.into_iter()
-				.map(|(_, pdu)| pdu.to_room_event())
-				.collect();
-
-			resp.start = from.stringify();
-			resp.end = next_token.map(|count| count.stringify());
-			resp.chunk = events_before;
-		},
+		.lazy_load_confirm_delivery(sender_user, sender_device, room_id, from);
+
+	if matches!(body.dir, Direction::Backward) {
+		services
+			.rooms
+			.timeline
+			.backfill_if_required(room_id, from)
+			.boxed()
+			.await
+			.log_err()
+			.ok();
 	}

-	resp.state = Vec::new();
-	for ll_id in &lazy_loaded {
-		if let Some(member_event) =
-			services
-				.rooms
-				.state_accessor
-				.room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())?
-		{
-			resp.state.push(member_event.to_state_event());
-		}
-	}
+	let it = match body.dir {
+		Direction::Forward => services
+			.rooms
+			.timeline
+			.pdus(Some(sender_user), room_id, Some(from))
+			.await?
+			.boxed(),
+
+		Direction::Backward => services
+			.rooms
+			.timeline
+			.pdus_rev(Some(sender_user), room_id, Some(from))
+			.await?
+			.boxed(),
+	};
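// The `from`/`to` token handling above leans on a parse-then-default pattern:
// Option<&str> -> Option<Result<T>> -> Result<Option<T>> -> T, with the
// fallback chosen by pagination direction. Isolated for illustration as a
// standalone sketch (the generic stands in for PduCount or any FromStr token):

use std::str::FromStr;

fn parse_from_token<T: FromStr>(from: Option<&str>, default: T) -> Result<T, T::Err> {
	from.map(str::parse)                   // Option<Result<T, T::Err>>
		.transpose()                       // Result<Option<T>, T::Err>: malformed tokens are hard errors
		.map(|opt| opt.unwrap_or(default)) // absent tokens fall back to the direction default
}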
+
+	let events: Vec<_> = it
+		.ready_take_while(|(count, _)| Some(*count) != to)
+		.ready_filter_map(|item| event_filter(item, filter))
+		.filter_map(|item| ignored_filter(&services, item, sender_user))
+		.filter_map(|item| visibility_filter(&services, item, sender_user))
+		.take(limit)
+		.collect()
+		.await;
+
+	let lazy = events
+		.iter()
+		.stream()
+		.fold(LazySet::new(), |lazy, item| {
+			update_lazy(&services, room_id, sender, lazy, item, false)
+		})
+		.await;
+
+	let state = lazy
+		.iter()
+		.stream()
+		.filter_map(|user_id| get_member_event(&services, room_id, user_id))
+		.collect()
+		.await;
+
+	let start_token = events.first().map(at!(0)).unwrap_or(from);
+
+	let next_token = events.last().map(at!(0));

-	// remove the feature check when we are sure clients like element can handle it
 	if !cfg!(feature = "element_hacks") {
 		if let Some(next_token) = next_token {
 			services
 				.rooms
 				.lazy_loading
-				.lazy_load_mark_sent(sender_user, sender_device, &body.room_id, lazy_loaded, next_token)
-				.await;
+				.lazy_load_mark_sent(sender_user, sender_device, room_id, lazy, next_token);
 		}
 	}

-	Ok(resp)
+	let chunk = events
+		.into_iter()
+		.map(at!(1))
+		.map(|pdu| pdu.to_room_event())
+		.collect();
+
+	Ok(get_message_events::v3::Response {
+		start: start_token.to_string(),
+		end: next_token.as_ref().map(ToString::to_string),
+		chunk,
+		state,
+	})
 }

-fn visibility_filter(services: &Services, pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool {
+async fn get_member_event(services: &Services, room_id: &RoomId, user_id: &UserId) -> Option<Raw<AnyStateEvent>> {
 	services
 		.rooms
 		.state_accessor
-		.user_can_see_event(user_id, room_id, &pdu.event_id)
-		.unwrap_or(false)
+		.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())
+		.await
+		.map(|member_event| member_event.to_state_event())
+		.ok()
+}
+
+pub(crate) async fn update_lazy(
+	services: &Services, room_id: &RoomId, sender: (&UserId, &DeviceId), mut lazy: LazySet, item: &PdusIterItem,
+	force: bool,
+) -> LazySet {
+	let (_, event) = &item;
+	let (sender_user, sender_device) = sender;
+
+	/* TODO: Remove the not "element_hacks" check when these are resolved:
+	 * https://github.com/vector-im/element-android/issues/3417
+	 * https://github.com/vector-im/element-web/issues/21034
+	 */
+	if force || cfg!(feature = "element_hacks") {
+		lazy.insert(event.sender().into());
+		return lazy;
+	}
+
+	if lazy.contains(event.sender()) {
+		return lazy;
+	}
+
+	if !services
+		.rooms
+		.lazy_loading
+		.lazy_load_was_sent_before(sender_user, sender_device, room_id, event.sender())
+		.await
+	{
+		lazy.insert(event.sender().into());
+	}
+
+	lazy
 }

-fn contains_url_filter(pdu: &PduEvent, filter: &RoomEventFilter) -> bool {
-	if filter.url_filter.is_none() {
-		return true;
+pub(crate) async fn ignored_filter(services: &Services, item: PdusIterItem, user_id: &UserId) -> Option<PdusIterItem> {
+	let (_, pdu) = &item;
+
+	if pdu.kind.to_cow_str() == "org.matrix.dummy_event" {
+		return None;
+	}
+
+	if !IGNORED_MESSAGE_TYPES.iter().any(is_equal_to!(&pdu.kind)) {
+		return Some(item);
 	}

-	let content: Value = from_str(pdu.content.get()).unwrap();
-	match filter.url_filter {
-		Some(UrlFilter::EventsWithoutUrl) => !content["url"].is_string(),
-		Some(UrlFilter::EventsWithUrl) => content["url"].is_string(),
-		None => true,
+	if !services.users.user_is_ignored(&pdu.sender, user_id).await {
+		return Some(item);
 	}
+
+	None
+}
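// Note the filter ordering in the pipeline above: the cheap, synchronous
// `event_filter` runs first (via `ready_filter_map`), so the async
// `ignored_filter` and `visibility_filter` only hit the database for events
// that survive the client's filter. `is_equal_to!`, used in `ignored_filter`,
// is a small closure-producing macro from conduit's utils; its assumed shape:

macro_rules! is_equal_to {
	($val:expr) => {
		// yields a closure comparing each element against the captured value
		|x| x == $val
	};
}

// usage as above: IGNORED_MESSAGE_TYPES.iter().any(is_equal_to!(&pdu.kind))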
+
+pub(crate) async fn visibility_filter(
+	services: &Services, item: PdusIterItem, user_id: &UserId,
+) -> Option<PdusIterItem> {
+	let (_, pdu) = &item;
+
+	services
+		.rooms
+		.state_accessor
+		.user_can_see_event(user_id, &pdu.room_id, &pdu.event_id)
+		.await
+		.then_some(item)
+}
+
+pub(crate) fn event_filter(item: PdusIterItem, filter: &RoomEventFilter) -> Option<PdusIterItem> {
+	let (_, pdu) = &item;
+	pdu.matches(filter).then_some(item)
 }
diff --git a/src/api/client/mod.rs b/src/api/client/mod.rs
index 4b7b64b9180c2f05ec29cc13346a1f6f670ab598..9ee88bec19d1220a26162b08c4a98e37ea2988f5 100644
--- a/src/api/client/mod.rs
+++ b/src/api/client/mod.rs
@@ -23,6 +23,7 @@ pub(super) mod report;
 pub(super) mod room;
 pub(super) mod search;
+pub(super) mod send;
 pub(super) mod session;
 pub(super) mod space;
 pub(super) mod state;
@@ -52,7 +53,7 @@ pub(super) use media::*;
 pub(super) use media_legacy::*;
 pub(super) use membership::*;
-pub use membership::{join_room_by_id_helper, leave_all_rooms, leave_room, validate_and_add_event_id};
+pub use membership::{join_room_by_id_helper, leave_all_rooms, leave_room};
 pub(super) use message::*;
 pub(super) use openid::*;
 pub(super) use presence::*;
@@ -65,6 +66,7 @@ pub(super) use report::*;
 pub(super) use room::*;
 pub(super) use search::*;
+pub(super) use send::*;
 pub(super) use session::*;
 pub(super) use space::*;
 pub(super) use state::*;
diff --git a/src/api/client/presence.rs b/src/api/client/presence.rs
index 8384d5acae49f3ae78c2cc892ba2a0158c431ee2..ba48808bd566e154cdd1a9d581005c677c830f8f 100644
--- a/src/api/client/presence.rs
+++ b/src/api/client/presence.rs
@@ -28,7 +28,8 @@ pub(crate) async fn set_presence_route(

 	services
 		.presence
-		.set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())?;
+		.set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())
+		.await?;

 	Ok(set_presence::v3::Response {})
 }
@@ -49,14 +50,15 @@ pub(crate) async fn get_presence_route(

 	let mut presence_event = None;

-	for _room_id in services
+	let has_shared_rooms = services
 		.rooms
 		.user
-		.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
-	{
-		if let Some(presence) = services.presence.get_presence(&body.user_id)? {
+		.has_shared_rooms(sender_user, &body.user_id)
+		.await;
+
+	if has_shared_rooms {
+		if let Ok(presence) = services.presence.get_presence(&body.user_id).await {
 			presence_event = Some(presence);
-			break;
 		}
 	}
diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs
index bf47a3f858671ef6ebe6766e60b0418e2eed21b0..32f7a72363e43b94ae040e52564bd588e3678548 100644
--- a/src/api/client/profile.rs
+++ b/src/api/client/profile.rs
@@ -1,5 +1,10 @@
 use axum::extract::State;
-use conduit::{pdu::PduBuilder, warn, Err, Error, Result};
+use conduit::{
+	pdu::PduBuilder,
+	utils::{stream::TryIgnore, IterStream},
+	warn, Err, Error, Result,
+};
+use futures::{StreamExt, TryStreamExt};
 use ruma::{
 	api::{
 		client::{
@@ -8,11 +13,10 @@
 		},
 		federation,
 	},
-	events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType},
+	events::{room::member::RoomMemberEventContent, StateEventType},
 	presence::PresenceState,
 	OwnedMxcUri, OwnedRoomId, UserId,
 };
-use serde_json::value::to_raw_value;
 use service::Services;

 use crate::Ruma;
@@ -35,16 +39,18 @@ pub(crate) async fn set_displayname_route(
 		.rooms
 		.state_cache
 		.rooms_joined(&body.user_id)
-		.filter_map(Result::ok)
-		.collect();
+		.map(ToOwned::to_owned)
+		.collect()
+		.await;

-	update_displayname(&services, &body.user_id, body.displayname.clone(), all_joined_rooms).await?;
+	update_displayname(&services, &body.user_id, body.displayname.clone(), &all_joined_rooms).await?;

 	if services.globals.allow_local_presence() {
 		// Presence update
 		services
 			.presence
-			.ping_presence(&body.user_id, &PresenceState::Online)?;
+			.ping_presence(&body.user_id, &PresenceState::Online)
+			.await?;
 	}

 	Ok(set_display_name::v3::Response {})
@@ -72,22 +78,19 @@ pub(crate) async fn get_displayname_route(
 	)
 	.await
 	{
-		if !services.users.exists(&body.user_id)? {
+		if !services.users.exists(&body.user_id).await {
 			services.users.create(&body.user_id, None)?;
 		}

 		services
 			.users
-			.set_displayname(&body.user_id, response.displayname.clone())
-			.await?;
+			.set_displayname(&body.user_id, response.displayname.clone());

 		services
 			.users
-			.set_avatar_url(&body.user_id, response.avatar_url.clone())
-			.await?;
+			.set_avatar_url(&body.user_id, response.avatar_url.clone());

 		services
 			.users
-			.set_blurhash(&body.user_id, response.blurhash.clone())
-			.await?;
+			.set_blurhash(&body.user_id, response.blurhash.clone());

 		return Ok(get_display_name::v3::Response {
 			displayname: response.displayname,
@@ -95,14 +98,14 @@ pub(crate) async fn get_displayname_route(
 		}
 	}

-	if !services.users.exists(&body.user_id)? {
+	if !services.users.exists(&body.user_id).await {
 		// Return 404 if this user doesn't exist and we couldn't fetch it over
 		// federation
 		return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
 	}

 	Ok(get_display_name::v3::Response {
-		displayname: services.users.displayname(&body.user_id)?,
+		displayname: services.users.displayname(&body.user_id).await.ok(),
 	})
 }
@@ -124,15 +127,16 @@ pub(crate) async fn set_avatar_url_route(
 		.rooms
 		.state_cache
 		.rooms_joined(&body.user_id)
-		.filter_map(Result::ok)
-		.collect();
+		.map(ToOwned::to_owned)
+		.collect()
+		.await;

 	update_avatar_url(
 		&services,
 		&body.user_id,
 		body.avatar_url.clone(),
 		body.blurhash.clone(),
-		all_joined_rooms,
+		&all_joined_rooms,
 	)
 	.await?;

@@ -140,7 +144,9 @@ pub(crate) async fn set_avatar_url_route(
 		// Presence update
 		services
 			.presence
-			.ping_presence(&body.user_id, &PresenceState::Online)?;
+			.ping_presence(&body.user_id, &PresenceState::Online)
+			.await
+			.ok();
 	}

 	Ok(set_avatar_url::v3::Response {})
@@ -168,22 +174,21 @@ pub(crate) async fn get_avatar_url_route(
 	)
 	.await
 	{
-		if !services.users.exists(&body.user_id)? {
+		if !services.users.exists(&body.user_id).await {
 			services.users.create(&body.user_id, None)?;
 		}

 		services
 			.users
-			.set_displayname(&body.user_id, response.displayname.clone())
-			.await?;
+			.set_displayname(&body.user_id, response.displayname.clone());
+
 		services
 			.users
-			.set_avatar_url(&body.user_id, response.avatar_url.clone())
-			.await?;
+			.set_avatar_url(&body.user_id, response.avatar_url.clone());
+
 		services
 			.users
-			.set_blurhash(&body.user_id, response.blurhash.clone())
-			.await?;
+			.set_blurhash(&body.user_id, response.blurhash.clone());

 		return Ok(get_avatar_url::v3::Response {
 			avatar_url: response.avatar_url,
@@ -192,15 +197,15 @@ pub(crate) async fn get_avatar_url_route(
 		}
 	}

-	if !services.users.exists(&body.user_id)? {
+	if !services.users.exists(&body.user_id).await {
 		// Return 404 if this user doesn't exist and we couldn't fetch it over
 		// federation
 		return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
 	}

 	Ok(get_avatar_url::v3::Response {
-		avatar_url: services.users.avatar_url(&body.user_id)?,
-		blurhash: services.users.blurhash(&body.user_id)?,
+		avatar_url: services.users.avatar_url(&body.user_id).await.ok(),
+		blurhash: services.users.blurhash(&body.user_id).await.ok(),
 	})
 }
@@ -226,31 +231,30 @@ pub(crate) async fn get_profile_route(
 	)
 	.await
 	{
-		if !services.users.exists(&body.user_id)? {
+		if !services.users.exists(&body.user_id).await {
 			services.users.create(&body.user_id, None)?;
 		}

 		services
 			.users
-			.set_displayname(&body.user_id, response.displayname.clone())
-			.await?;
+			.set_displayname(&body.user_id, response.displayname.clone());
+
 		services
 			.users
-			.set_avatar_url(&body.user_id, response.avatar_url.clone())
-			.await?;
+			.set_avatar_url(&body.user_id, response.avatar_url.clone());
+
 		services
 			.users
-			.set_blurhash(&body.user_id, response.blurhash.clone())
-			.await?;
+			.set_blurhash(&body.user_id, response.blurhash.clone());
+
 		services
 			.users
-			.set_timezone(&body.user_id, response.tz.clone())
-			.await?;
+			.set_timezone(&body.user_id, response.tz.clone());

 		for (profile_key, profile_key_value) in &response.custom_profile_fields {
 			services
 				.users
-				.set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone()))?;
+				.set_profile_key(&body.user_id, profile_key, Some(profile_key_value.clone()));
 		}

 		return Ok(get_profile::v3::Response {
@@ -263,134 +267,108 @@ pub(crate) async fn get_profile_route(
 		}
 	}

-	if !services.users.exists(&body.user_id)? {
+	if !services.users.exists(&body.user_id).await {
 		// Return 404 if this user doesn't exist and we couldn't fetch it over
 		// federation
 		return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
 	}

 	Ok(get_profile::v3::Response {
-		avatar_url: services.users.avatar_url(&body.user_id)?,
-		blurhash: services.users.blurhash(&body.user_id)?,
-		displayname: services.users.displayname(&body.user_id)?,
-		tz: services.users.timezone(&body.user_id)?,
+		avatar_url: services.users.avatar_url(&body.user_id).await.ok(),
+		blurhash: services.users.blurhash(&body.user_id).await.ok(),
+		displayname: services.users.displayname(&body.user_id).await.ok(),
+		tz: services.users.timezone(&body.user_id).await.ok(),
 		custom_profile_fields: services
 			.users
 			.all_profile_keys(&body.user_id)
-			.filter_map(Result::ok)
-			.collect(),
+			.collect()
+			.await,
 	})
 }

 pub async fn update_displayname(
-	services: &Services, user_id: &UserId, displayname: Option<String>, all_joined_rooms: Vec<OwnedRoomId>,
+	services: &Services, user_id: &UserId, displayname: Option<String>, all_joined_rooms: &[OwnedRoomId],
 ) -> Result<()> {
-	let current_display_name = services.users.displayname(user_id).unwrap_or_default();
+	let current_display_name = services.users.displayname(user_id).await.ok();

 	if displayname == current_display_name {
 		return Ok(());
 	}

-	services
-		.users
-		.set_displayname(user_id, displayname.clone())
-		.await?;
+	services.users.set_displayname(user_id, displayname.clone());

 	// Send a new join membership event into all joined rooms
-	let all_joined_rooms: Vec<_> = all_joined_rooms
-		.iter()
-		.map(|room_id| {
-			Ok::<_, Error>((
-				PduBuilder {
-					event_type: TimelineEventType::RoomMember,
-					content: to_raw_value(&RoomMemberEventContent {
-						displayname: displayname.clone(),
-						join_authorized_via_users_server: None,
-						..serde_json::from_str(
-							services
-								.rooms
-								.state_accessor
-								.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?
-								.ok_or_else(|| {
-									Error::bad_database("Tried to send display name update for user not in the room.")
-								})?
-								.content
-								.get(),
-						)
-						.map_err(|_| Error::bad_database("Database contains invalid PDU."))?
-					})
-					.expect("event is valid, we just created it"),
-					unsigned: None,
-					state_key: Some(user_id.to_string()),
-					redacts: None,
-					timestamp: None,
-				},
-				room_id,
-			))
-		})
-		.filter_map(Result::ok)
-		.collect();
+	let mut joined_rooms = Vec::new();
+	for room_id in all_joined_rooms {
+		let Ok(content) = services
+			.rooms
+			.state_accessor
+			.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str())
+			.await
+		else {
+			continue;
+		};
+
+		let pdu = PduBuilder::state(
+			user_id.to_string(),
+			&RoomMemberEventContent {
+				displayname: displayname.clone(),
+				join_authorized_via_users_server: None,
+				..content
+			},
+		);
+
+		joined_rooms.push((pdu, room_id));
+	}

-	update_all_rooms(services, all_joined_rooms, user_id).await;
+	update_all_rooms(services, joined_rooms, user_id).await;

 	Ok(())
 }

 pub async fn update_avatar_url(
 	services: &Services, user_id: &UserId, avatar_url: Option<OwnedMxcUri>, blurhash: Option<String>,
-	all_joined_rooms: Vec<OwnedRoomId>,
+	all_joined_rooms: &[OwnedRoomId],
 ) -> Result<()> {
-	let current_avatar_url = services.users.avatar_url(user_id).unwrap_or_default();
-	let current_blurhash = services.users.blurhash(user_id).unwrap_or_default();
+	let current_avatar_url = services.users.avatar_url(user_id).await.ok();
+	let current_blurhash = services.users.blurhash(user_id).await.ok();

 	if current_avatar_url == avatar_url && current_blurhash == blurhash {
 		return Ok(());
 	}

-	services
-		.users
-		.set_avatar_url(user_id, avatar_url.clone())
-		.await?;
-	services
-		.users
-		.set_blurhash(user_id, blurhash.clone())
-		.await?;
+	services.users.set_avatar_url(user_id, avatar_url.clone());
+
+	services.users.set_blurhash(user_id, blurhash.clone());

 	// Send a new join membership event into all joined rooms
+	let avatar_url = &avatar_url;
+	let blurhash = &blurhash;
 	let all_joined_rooms: Vec<_> = all_joined_rooms
 		.iter()
-		.map(|room_id| {
-			Ok::<_, Error>((
-				PduBuilder {
-					event_type: TimelineEventType::RoomMember,
-					content: to_raw_value(&RoomMemberEventContent {
-						avatar_url: avatar_url.clone(),
-						blurhash: blurhash.clone(),
-						join_authorized_via_users_server: None,
-						..serde_json::from_str(
-							services
-								.rooms
-								.state_accessor
-								.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?
-								.ok_or_else(|| {
-									Error::bad_database("Tried to send avatar URL update for user not in the room.")
-								})?
-								.content
-								.get(),
-						)
-						.map_err(|_| Error::bad_database("Database contains invalid PDU."))?
-					})
-					.expect("event is valid, we just created it"),
-					unsigned: None,
-					state_key: Some(user_id.to_string()),
-					redacts: None,
-					timestamp: None,
+		.try_stream()
+		.and_then(|room_id: &OwnedRoomId| async move {
+			let content = services
+				.rooms
+				.state_accessor
+				.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str())
+				.await?;
+
+			let pdu = PduBuilder::state(
+				user_id.to_string(),
+				&RoomMemberEventContent {
+					avatar_url: avatar_url.clone(),
+					blurhash: blurhash.clone(),
+					join_authorized_via_users_server: None,
+					..content
 				},
-				room_id,
-			))
+			);
+
+			Ok((pdu, room_id))
 		})
-		.filter_map(Result::ok)
-		.collect();
+		.ignore_err()
+		.collect()
+		.await;

 	update_all_rooms(services, all_joined_rooms, user_id).await;
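// Both profile helpers hand their (PduBuilder, room_id) pairs to
// `update_all_rooms`, which this diff calls but never shows. A plausible
// shape, assuming it appends each membership PDU under the room's state lock
// and only logs failures (hypothetical reconstruction, not the confirmed
// definition):

pub async fn update_all_rooms(services: &Services, all_rooms: Vec<(PduBuilder, &OwnedRoomId)>, user_id: &UserId) {
	for (pdu_builder, room_id) in all_rooms {
		// serialize membership updates per room
		let state_lock = services.rooms.state.mutex.lock(room_id).await;
		if let Err(e) = services
			.rooms
			.timeline
			.build_and_append_pdu(pdu_builder, user_id, room_id, &state_lock)
			.await
		{
			warn!(%user_id, %room_id, %e, "Failed to update membership event");
		}
	}
}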
diff --git a/src/api/client/push.rs b/src/api/client/push.rs
index 8723e676bcdf5d85d7fcef98f0de6871e603dbfb..97243ab451a913b6d0676d2109ef935f9398ef4b 100644
--- a/src/api/client/push.rs
+++ b/src/api/client/push.rs
@@ -1,19 +1,19 @@
 use axum::extract::State;
-use conduit::err;
+use conduit::{err, Err};
 use ruma::{
 	api::client::{
 		error::ErrorKind,
 		push::{
 			delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, get_pushrules_all,
-			set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled, RuleScope,
+			get_pushrules_global_scope, set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled,
 		},
 	},
 	events::{
 		push_rules::{PushRulesEvent, PushRulesEventContent},
 		GlobalAccountDataEventType,
 	},
-	push::{InsertPushRuleError, RemovePushRuleError, Ruleset},
-	CanonicalJsonObject,
+	push::{InsertPushRuleError, PredefinedContentRuleId, PredefinedOverrideRuleId, RemovePushRuleError, Ruleset},
+	CanonicalJsonObject, CanonicalJsonValue,
 };
 use service::Services;

@@ -27,44 +27,109 @@ pub(crate) async fn get_pushrules_all_route(
 ) -> Result<get_pushrules_all::v3::Response> {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	let global_ruleset: Ruleset;
-
-	let Ok(event) =
-		services
-			.account_data
-			.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())
+	let Some(content_value) = services
+		.account_data
+		.get_global::<CanonicalJsonObject>(sender_user, GlobalAccountDataEventType::PushRules)
+		.await
+		.ok()
+		.and_then(|event| event.get("content").cloned())
+		.filter(CanonicalJsonValue::is_object)
 	else {
-		// push rules event doesn't exist, create it and return default
-		return recreate_push_rules_and_return(&services, sender_user);
+		// user somehow has non-existent push rule event. recreate it and return server
+		// default silently
+		return recreate_push_rules_and_return(&services, sender_user).await;
 	};

-	if let Some(event) = event {
-		let value = serde_json::from_str::<CanonicalJsonObject>(event.get())
-			.map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?;
+	let account_data_content = serde_json::from_value::<PushRulesEventContent>(content_value.into())
+		.map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?;

-		let Some(content_value) = value.get("content") else {
-			// user somehow has a push rule event with no content key, recreate it and
-			// return server default silently
-			return recreate_push_rules_and_return(&services, sender_user);
-		};
+	let mut global_ruleset = account_data_content.global;

-		if content_value.to_string().is_empty() {
-			// user somehow has a push rule event with empty content, recreate it and return
-			// server default silently
-			return recreate_push_rules_and_return(&services, sender_user);
-		}
+	// remove old deprecated mentions push rules as per MSC4210
+	#[allow(deprecated)]
+	{
+		use ruma::push::RuleKind::*;
+
+		global_ruleset
+			.remove(Override, PredefinedOverrideRuleId::ContainsDisplayName)
+			.ok();
+		global_ruleset
+			.remove(Override, PredefinedOverrideRuleId::RoomNotif)
+			.ok();
+
+		global_ruleset
+			.remove(Content, PredefinedContentRuleId::ContainsUserName)
+			.ok();
+	};

-		let account_data_content = serde_json::from_value::<PushRulesEventContent>(content_value.clone().into())
-			.map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?;
+	Ok(get_pushrules_all::v3::Response {
+		global: global_ruleset,
+	})
+}

-		global_ruleset = account_data_content.global;
-	} else {
+/// # `GET /_matrix/client/r0/pushrules/global/`
+///
+/// Retrieves the push rules event for this user.
+///
+/// This appears to be the exact same as `GET /_matrix/client/r0/pushrules/`.
+pub(crate) async fn get_pushrules_global_route(
+	State(services): State<crate::State>, body: Ruma<get_pushrules_global_scope::v3::Request>,
+) -> Result<get_pushrules_global_scope::v3::Response> {
+	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+	let Some(content_value) = services
+		.account_data
+		.get_global::<CanonicalJsonObject>(sender_user, GlobalAccountDataEventType::PushRules)
+		.await
+		.ok()
+		.and_then(|event| event.get("content").cloned())
+		.filter(CanonicalJsonValue::is_object)
+	else {
 		// user somehow has non-existent push rule event. recreate it and return server
 		// default silently
-		return recreate_push_rules_and_return(&services, sender_user);
-	}
+		services
+			.account_data
+			.update(
+				None,
+				sender_user,
+				GlobalAccountDataEventType::PushRules.to_string().into(),
+				&serde_json::to_value(PushRulesEvent {
+					content: PushRulesEventContent {
+						global: Ruleset::server_default(sender_user),
+					},
+				})
+				.expect("to json always works"),
+			)
+			.await?;
+
+		return Ok(get_pushrules_global_scope::v3::Response {
+			global: Ruleset::server_default(sender_user),
+		});
+	};

-	Ok(get_pushrules_all::v3::Response {
+	let account_data_content = serde_json::from_value::<PushRulesEventContent>(content_value.into())
+		.map_err(|e| err!(Database(warn!("Invalid push rules account data event in database: {e}"))))?;
+
+	let mut global_ruleset = account_data_content.global;
+
+	// remove old deprecated mentions push rules as per MSC4210
+	#[allow(deprecated)]
+	{
+		use ruma::push::RuleKind::*;
+
+		global_ruleset
+			.remove(Override, PredefinedOverrideRuleId::ContainsDisplayName)
+			.ok();
+		global_ruleset
+			.remove(Override, PredefinedOverrideRuleId::RoomNotif)
+			.ok();
+
+		global_ruleset
+			.remove(Content, PredefinedContentRuleId::ContainsUserName)
+			.ok();
+	};
+
+	Ok(get_pushrules_global_scope::v3::Response {
 		global: global_ruleset,
 	})
 }

@@ -77,16 +142,23 @@ pub(crate) async fn get_pushrule_route(
 ) -> Result<get_pushrule::v3::Response> {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	let event = services
-		.account_data
-		.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
-		.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
+	// remove old deprecated mentions push rules as per MSC4210
+	#[allow(deprecated)]
+	if body.rule_id.as_str() == PredefinedContentRuleId::ContainsUserName.as_str()
+		|| body.rule_id.as_str() == PredefinedOverrideRuleId::ContainsDisplayName.as_str()
+		|| body.rule_id.as_str() == PredefinedOverrideRuleId::RoomNotif.as_str()
+	{
+		return Err!(Request(NotFound("Push rule not found.")));
+	}

-	let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
-		.map_err(|_| Error::bad_database("Invalid account data event in db."))?
-		.content;
+	let event: PushRulesEvent = services
+		.account_data
+		.get_global(sender_user, GlobalAccountDataEventType::PushRules)
+		.await
+		.map_err(|_| err!(Request(NotFound("PushRules event not found."))))?;

-	let rule = account_data
+	let rule = event
+		.content
 		.global
 		.get(body.kind.clone(), &body.rule_id)
 		.map(Into::into);
@@ -100,7 +172,7 @@ pub(crate) async fn get_pushrule_route(
 	}
 }

-/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
+/// # `PUT /_matrix/client/r0/pushrules/global/{kind}/{ruleId}`
 ///
 /// Creates a single specified push rule for this user.
 pub(crate) async fn set_pushrule_route(
@@ -109,20 +181,11 @@ pub(crate) async fn set_pushrule_route(
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
 	let body = body.body;

-	if body.scope != RuleScope::Global {
-		return Err(Error::BadRequest(
-			ErrorKind::InvalidParam,
-			"Scopes other than 'global' are not supported.",
-		));
-	}
-
-	let event = services
+	let mut account_data: PushRulesEvent = services
 		.account_data
-		.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
-		.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
-
-	let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
-		.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
+		.get_global(sender_user, GlobalAccountDataEventType::PushRules)
+		.await
+		.map_err(|_| err!(Request(NotFound("PushRules event not found."))))?;

 	if let Err(error) = account_data
@@ -155,17 +218,20 @@ pub(crate) async fn set_pushrule_route(
 		return Err(err);
 	}

-	services.account_data.update(
-		None,
-		sender_user,
-		GlobalAccountDataEventType::PushRules.to_string().into(),
-		&serde_json::to_value(account_data).expect("to json value always works"),
-	)?;
+	services
+		.account_data
+		.update(
+			None,
+			sender_user,
+			GlobalAccountDataEventType::PushRules.to_string().into(),
+			&serde_json::to_value(account_data).expect("to json value always works"),
+		)
+		.await?;

 	Ok(set_pushrule::v3::Response {})
 }

-/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
+/// # `GET /_matrix/client/r0/pushrules/global/{kind}/{ruleId}/actions`
 ///
 /// Gets the actions of a single specified push rule for this user.
 pub(crate) async fn get_pushrule_actions_route(
@@ -173,34 +239,34 @@ pub(crate) async fn get_pushrule_actions_route(
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	if body.scope != RuleScope::Global {
-		return Err(Error::BadRequest(
-			ErrorKind::InvalidParam,
-			"Scopes other than 'global' are not supported.",
-		));
+	// remove old deprecated mentions push rules as per MSC4210
+	#[allow(deprecated)]
+	if body.rule_id.as_str() == PredefinedContentRuleId::ContainsUserName.as_str()
+		|| body.rule_id.as_str() == PredefinedOverrideRuleId::ContainsDisplayName.as_str()
+		|| body.rule_id.as_str() == PredefinedOverrideRuleId::RoomNotif.as_str()
+	{
+		return Err!(Request(NotFound("Push rule not found.")));
 	}

-	let event = services
+	let event: PushRulesEvent = services
 		.account_data
-		.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
-		.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
-
-	let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
-		.map_err(|_| Error::bad_database("Invalid account data event in db."))?
-		.content;
+		.get_global(sender_user, GlobalAccountDataEventType::PushRules)
+		.await
+		.map_err(|_| err!(Request(NotFound("PushRules event not found."))))?;

-	let global = account_data.global;
-	let actions = global
+	let actions = event
+		.content
+		.global
 		.get(body.kind.clone(), &body.rule_id)
 		.map(|rule| rule.actions().to_owned())
-		.ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?;
+		.ok_or(err!(Request(NotFound("Push rule not found."))))?;

 	Ok(get_pushrule_actions::v3::Response {
 		actions,
 	})
 }
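// Every handler in this file now opens with the same fetch of the push-rules
// account-data event. If further de-duplication were wanted, the repeated
// stanza could collapse into one helper; a sketch only, not part of this diff
// (name is hypothetical, types mirror the calls shown above):

async fn global_push_rules(services: &Services, sender_user: &UserId) -> Result<PushRulesEvent> {
	services
		.account_data
		.get_global(sender_user, GlobalAccountDataEventType::PushRules)
		.await
		.map_err(|_| err!(Request(NotFound("PushRules event not found."))))
}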
-/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
+/// # `PUT /_matrix/client/r0/pushrules/global/{kind}/{ruleId}/actions`
 ///
 /// Sets the actions of a single specified push rule for this user.
 pub(crate) async fn set_pushrule_actions_route(
@@ -208,20 +274,11 @@ pub(crate) async fn set_pushrule_actions_route(
 ) -> Result<set_pushrule_actions::v3::Response> {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	if body.scope != RuleScope::Global {
-		return Err(Error::BadRequest(
-			ErrorKind::InvalidParam,
-			"Scopes other than 'global' are not supported.",
-		));
-	}
-
-	let event = services
+	let mut account_data: PushRulesEvent = services
 		.account_data
-		.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
-		.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
-
-	let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
-		.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
+		.get_global(sender_user, GlobalAccountDataEventType::PushRules)
+		.await
+		.map_err(|_| err!(Request(NotFound("PushRules event not found."))))?;

 	if account_data
 		.content
@@ -232,17 +289,20 @@ pub(crate) async fn set_pushrule_actions_route(
 		return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
 	}

-	services.account_data.update(
-		None,
-		sender_user,
-		GlobalAccountDataEventType::PushRules.to_string().into(),
-		&serde_json::to_value(account_data).expect("to json value always works"),
-	)?;
+	services
+		.account_data
+		.update(
+			None,
+			sender_user,
+			GlobalAccountDataEventType::PushRules.to_string().into(),
+			&serde_json::to_value(account_data).expect("to json value always works"),
+		)
+		.await?;

 	Ok(set_pushrule_actions::v3::Response {})
 }

-/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
+/// # `GET /_matrix/client/r0/pushrules/global/{kind}/{ruleId}/enabled`
 ///
 /// Gets the enabled status of a single specified push rule for this user.
 pub(crate) async fn get_pushrule_enabled_route(
@@ -250,33 +310,36 @@ pub(crate) async fn get_pushrule_enabled_route(
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	if body.scope != RuleScope::Global {
-		return Err(Error::BadRequest(
-			ErrorKind::InvalidParam,
-			"Scopes other than 'global' are not supported.",
-		));
+	// remove old deprecated mentions push rules as per MSC4210
+	#[allow(deprecated)]
+	if body.rule_id.as_str() == PredefinedContentRuleId::ContainsUserName.as_str()
+		|| body.rule_id.as_str() == PredefinedOverrideRuleId::ContainsDisplayName.as_str()
+		|| body.rule_id.as_str() == PredefinedOverrideRuleId::RoomNotif.as_str()
+	{
+		return Ok(get_pushrule_enabled::v3::Response {
+			enabled: false,
+		});
 	}

-	let event = services
+	let event: PushRulesEvent = services
 		.account_data
-		.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
-		.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
-
-	let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
-		.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
+		.get_global(sender_user, GlobalAccountDataEventType::PushRules)
+		.await
+		.map_err(|_| err!(Request(NotFound("PushRules event not found."))))?;

-	let global = account_data.content.global;
-	let enabled = global
+	let enabled = event
+		.content
+		.global
 		.get(body.kind.clone(), &body.rule_id)
 		.map(ruma::push::AnyPushRuleRef::enabled)
-		.ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?;
+		.ok_or(err!(Request(NotFound("Push rule not found."))))?;

 	Ok(get_pushrule_enabled::v3::Response {
 		enabled,
 	})
 }

-/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
+/// # `PUT /_matrix/client/r0/pushrules/global/{kind}/{ruleId}/enabled`
 ///
 /// Sets the enabled status of a single specified push rule for this user.
 pub(crate) async fn set_pushrule_enabled_route(
@@ -284,20 +347,11 @@ pub(crate) async fn set_pushrule_enabled_route(
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	if body.scope != RuleScope::Global {
-		return Err(Error::BadRequest(
-			ErrorKind::InvalidParam,
-			"Scopes other than 'global' are not supported.",
-		));
-	}
-
-	let event = services
+	let mut account_data: PushRulesEvent = services
 		.account_data
-		.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
-		.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
-
-	let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
-		.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
+		.get_global(sender_user, GlobalAccountDataEventType::PushRules)
+		.await
+		.map_err(|_| err!(Request(NotFound("PushRules event not found."))))?;

 	if account_data
 		.content
@@ -308,17 +362,20 @@ pub(crate) async fn set_pushrule_enabled_route(
 		return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
 	}

-	services.account_data.update(
-		None,
-		sender_user,
-		GlobalAccountDataEventType::PushRules.to_string().into(),
-		&serde_json::to_value(account_data).expect("to json value always works"),
-	)?;
+	services
+		.account_data
+		.update(
+			None,
+			sender_user,
+			GlobalAccountDataEventType::PushRules.to_string().into(),
+			&serde_json::to_value(account_data).expect("to json value always works"),
+		)
+		.await?;

 	Ok(set_pushrule_enabled::v3::Response {})
 }

-/// # `DELETE /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
+/// # `DELETE /_matrix/client/r0/pushrules/global/{kind}/{ruleId}`
 ///
 /// Deletes a single specified push rule for this user.
 pub(crate) async fn delete_pushrule_route(
@@ -326,20 +383,11 @@ pub(crate) async fn delete_pushrule_route(
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	if body.scope != RuleScope::Global {
-		return Err(Error::BadRequest(
-			ErrorKind::InvalidParam,
-			"Scopes other than 'global' are not supported.",
-		));
-	}
-
-	let event = services
+	let mut account_data: PushRulesEvent = services
 		.account_data
-		.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
-		.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
-
-	let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
-		.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
+		.get_global(sender_user, GlobalAccountDataEventType::PushRules)
+		.await
+		.map_err(|_| err!(Request(NotFound("PushRules event not found."))))?;

 	if let Err(error) = account_data
 		.content
@@ -357,12 +405,15 @@ pub(crate) async fn delete_pushrule_route(
 		return Err(err);
 	}

-	services.account_data.update(
-		None,
-		sender_user,
-		GlobalAccountDataEventType::PushRules.to_string().into(),
-		&serde_json::to_value(account_data).expect("to json value always works"),
-	)?;
+	services
+		.account_data
+		.update(
+			None,
+			sender_user,
+			GlobalAccountDataEventType::PushRules.to_string().into(),
+			&serde_json::to_value(account_data).expect("to json value always works"),
+		)
+		.await?;

 	Ok(delete_pushrule::v3::Response {})
 }
@@ -376,7 +427,7 @@ pub(crate) async fn get_pushers_route(
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

 	Ok(get_pushers::v3::Response {
-		pushers: services.pusher.get_pushers(sender_user)?,
+		pushers: services.pusher.get_pushers(sender_user).await,
 	})
 }

@@ -390,27 +441,30 @@ pub(crate) async fn set_pushers_route(
 ) -> Result<set_pusher::v3::Response> {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	services.pusher.set_pusher(sender_user, &body.action)?;
+	services.pusher.set_pusher(sender_user, &body.action);

 	Ok(set_pusher::v3::Response::default())
 }

 /// user somehow has bad push rules, these must always exist per spec.
 /// so recreate it and return server default silently
-fn recreate_push_rules_and_return(
+async fn recreate_push_rules_and_return(
 	services: &Services, sender_user: &ruma::UserId,
 ) -> Result<get_pushrules_all::v3::Response> {
-	services.account_data.update(
-		None,
-		sender_user,
-		GlobalAccountDataEventType::PushRules.to_string().into(),
-		&serde_json::to_value(PushRulesEvent {
-			content: PushRulesEventContent {
-				global: Ruleset::server_default(sender_user),
-			},
-		})
-		.expect("to json always works"),
-	)?;
+	services
+		.account_data
+		.update(
+			None,
+			sender_user,
+			GlobalAccountDataEventType::PushRules.to_string().into(),
+			&serde_json::to_value(PushRulesEvent {
+				content: PushRulesEventContent {
+					global: Ruleset::server_default(sender_user),
+				},
+			})
+			.expect("to json always works"),
+		)
+		.await?;

 	Ok(get_pushrules_all::v3::Response {
 		global: Ruleset::server_default(sender_user),
diff --git a/src/api/client/read_marker.rs b/src/api/client/read_marker.rs
index f40f2493262a1dc11e99b59921edc2f03e44cc66..f28b2aec5352227eecd69e4b523f4c11788090eb 100644
--- a/src/api/client/read_marker.rs
+++ b/src/api/client/read_marker.rs
@@ -31,27 +31,32 @@ pub(crate) async fn set_read_marker_route(
 				event_id: fully_read.clone(),
 			},
 		};
-		services.account_data.update(
-			Some(&body.room_id),
-			sender_user,
-			RoomAccountDataEventType::FullyRead,
-			&serde_json::to_value(fully_read_event).expect("to json value always works"),
-		)?;
+		services
+			.account_data
+			.update(
+				Some(&body.room_id),
+				sender_user,
+				RoomAccountDataEventType::FullyRead,
+				&serde_json::to_value(fully_read_event).expect("to json value always works"),
+			)
+			.await?;
 	}

 	if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
 		services
 			.rooms
 			.user
-			.reset_notification_counts(sender_user, &body.room_id)?;
+			.reset_notification_counts(sender_user, &body.room_id);
 	}

 	if let Some(event) = &body.private_read_receipt {
 		let count = services
 			.rooms
 			.timeline
-			.get_pdu_count(event)?
-			.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
+			.get_pdu_count(event)
+			.await
+			.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
+
 		let count = match count {
 			PduCount::Backfilled(_) => {
 				return Err(Error::BadRequest(
@@ -64,7 +69,7 @@ pub(crate) async fn set_read_marker_route(
 		services
 			.rooms
 			.read_receipt
-			.private_read_set(&body.room_id, sender_user, count)?;
+			.private_read_set(&body.room_id, sender_user, count);
 	}

 	if let Some(event) = &body.read_receipt {
@@ -83,14 +88,18 @@ pub(crate) async fn set_read_marker_route(
 		let mut receipt_content = BTreeMap::new();
 		receipt_content.insert(event.to_owned(), receipts);

-		services.rooms.read_receipt.readreceipt_update(
-			sender_user,
-			&body.room_id,
-			&ruma::events::receipt::ReceiptEvent {
-				content: ruma::events::receipt::ReceiptEventContent(receipt_content),
-				room_id: body.room_id.clone(),
-			},
-		)?;
+		services
+			.rooms
+			.read_receipt
+			.readreceipt_update(
+				sender_user,
+				&body.room_id,
+				&ruma::events::receipt::ReceiptEvent {
+					content: ruma::events::receipt::ReceiptEventContent(receipt_content),
+					room_id: body.room_id.clone(),
+				},
+			)
+			.await;
 	}

 	Ok(set_read_marker::v3::Response {})
@@ -111,7 +120,7 @@ pub(crate) async fn create_receipt_route(
 		services
 			.rooms
 			.user
-			.reset_notification_counts(sender_user, &body.room_id)?;
+			.reset_notification_counts(sender_user, &body.room_id);
 	}

 	match body.receipt_type {
@@ -121,12 +130,15 @@ pub(crate) async fn create_receipt_route(
 					event_id: body.event_id.clone(),
 				},
 			};
-			services.account_data.update(
-				Some(&body.room_id),
-				sender_user,
-				RoomAccountDataEventType::FullyRead,
-				&serde_json::to_value(fully_read_event).expect("to json value always works"),
-			)?;
+			services
+				.account_data
+				.update(
+					Some(&body.room_id),
+					sender_user,
+					RoomAccountDataEventType::FullyRead,
+					&serde_json::to_value(fully_read_event).expect("to json value always works"),
+				)
+				.await?;
 		},
 		create_receipt::v3::ReceiptType::Read => {
 			let mut user_receipts = BTreeMap::new();
@@ -143,21 +155,27 @@ pub(crate) async fn create_receipt_route(
 			let mut receipt_content = BTreeMap::new();
 			receipt_content.insert(body.event_id.clone(), receipts);

-			services.rooms.read_receipt.readreceipt_update(
-				sender_user,
-				&body.room_id,
-				&ruma::events::receipt::ReceiptEvent {
-					content: ruma::events::receipt::ReceiptEventContent(receipt_content),
-					room_id: body.room_id.clone(),
-				},
-			)?;
+			services
+				.rooms
+				.read_receipt
+				.readreceipt_update(
+					sender_user,
+					&body.room_id,
+					&ruma::events::receipt::ReceiptEvent {
+						content: ruma::events::receipt::ReceiptEventContent(receipt_content),
+						room_id: body.room_id.clone(),
+					},
+				)
+				.await;
 		},
 		create_receipt::v3::ReceiptType::ReadPrivate => {
 			let count = services
 				.rooms
 				.timeline
-				.get_pdu_count(&body.event_id)?
-				.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
+				.get_pdu_count(&body.event_id)
+				.await
+				.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
+
 			let count = match count {
 				PduCount::Backfilled(_) => {
 					return Err(Error::BadRequest(
@@ -170,7 +188,7 @@ pub(crate) async fn create_receipt_route(
 			services
 				.rooms
 				.read_receipt
-				.private_read_set(&body.room_id, sender_user, count)?;
+				.private_read_set(&body.room_id, sender_user, count);
 		},
 		_ => return Err(Error::bad_database("Unsupported receipt type")),
 	}
diff --git a/src/api/client/redact.rs b/src/api/client/redact.rs
index 2102f6cd58a33533aaab28f7d3001dc389eadbb1..a986dc18bd509fe7aa08a85449e3252e3cd923c2 100644
--- a/src/api/client/redact.rs
+++ b/src/api/client/redact.rs
@@ -1,9 +1,5 @@
 use axum::extract::State;
-use ruma::{
-	api::client::redact::redact_event,
-	events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
-};
-use serde_json::value::to_raw_value;
+use ruma::{api::client::redact::redact_event, events::room::redaction::RoomRedactionEventContent};

 use crate::{service::pdu::PduBuilder, Result, Ruma};

@@ -25,16 +21,11 @@ pub(crate) async fn redact_event_route(
 		.timeline
 		.build_and_append_pdu(
 			PduBuilder {
-				event_type: TimelineEventType::RoomRedaction,
-				content: to_raw_value(&RoomRedactionEventContent {
+				redacts: Some(body.event_id.clone().into()),
+				..PduBuilder::timeline(&RoomRedactionEventContent {
 					redacts: Some(body.event_id.clone()),
 					reason: body.reason.clone(),
 				})
-				.expect("event is valid, we just created it"),
-				unsigned: None,
-				state_key: None,
-				redacts: Some(body.event_id.into()),
-				timestamp: None,
 			},
 			sender_user,
 			&body.room_id,
@@ -44,8 +35,7 @@ pub(crate) async fn redact_event_route(

 	drop(state_lock);

-	let event_id = (*event_id).to_owned();
 	Ok(redact_event::v3::Response {
-		event_id,
+		event_id: event_id.into(),
 	})
 }
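// The redaction builder above sets `redacts` twice on purpose: once at the
// PDU top level and once inside the content, since room version 11 moved the
// `redacts` key from the event's top level into its content; populating both
// keeps the event valid either way. `PduBuilder::timeline` is the
// message-event counterpart of `PduBuilder::state`; its assumed shape (not
// shown in this diff):

impl PduBuilder {
	// Hypothetical sketch mirroring the state constructor.
	pub fn timeline<T>(content: &T) -> Self
	where
		T: EventContent<EventType = MessageLikeEventType> + serde::Serialize,
	{
		Self {
			event_type: content.event_type().into(),
			content: to_raw_value(content).expect("event content always serializes"),
			..Self::default()
		}
	}
}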
diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs
index ae64594000421c38c38a0a38af14286026913e2f..902e6be60cbc7b660995326e94b048a9a0ed4de3 100644
--- a/src/api/client/relations.rs
+++ b/src/api/client/relations.rs
@@ -1,30 +1,43 @@
 use axum::extract::State;
-use ruma::api::client::relations::{
-	get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type,
+use conduit::{
+	at,
+	utils::{result::FlatOk, IterStream, ReadyExt},
+	PduCount, Result,
 };
+use futures::{FutureExt, StreamExt};
+use ruma::{
+	api::{
+		client::relations::{
+			get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type,
+		},
+		Direction,
+	},
+	events::{relation::RelationType, TimelineEventType},
+	EventId, RoomId, UInt, UserId,
+};
+use service::{rooms::timeline::PdusIterItem, Services};

-use crate::{Result, Ruma};
+use crate::Ruma;

 /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}`
 pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
 	State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>,
 ) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> {
-	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
-
-	let res = services.rooms.pdu_metadata.paginate_relations_with_filter(
-		sender_user,
+	paginate_relations_with_filter(
+		&services,
+		body.sender_user(),
 		&body.room_id,
 		&body.event_id,
-		&Some(body.event_type.clone()),
-		&Some(body.rel_type.clone()),
-		&body.from,
-		&body.to,
-		&body.limit,
+		body.event_type.clone().into(),
+		body.rel_type.clone().into(),
+		body.from.as_deref(),
+		body.to.as_deref(),
+		body.limit,
 		body.recurse,
 		body.dir,
-	)?;
-
-	Ok(get_relating_events_with_rel_type_and_event_type::v1::Response {
+	)
+	.await
+	.map(|res| get_relating_events_with_rel_type_and_event_type::v1::Response {
 		chunk: res.chunk,
 		next_batch: res.next_batch,
 		prev_batch: res.prev_batch,
@@ -36,22 +49,21 @@ pub(crate) async fn get_relating_events_with_rel_type_route(
 	State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type::v1::Request>,
 ) -> Result<get_relating_events_with_rel_type::v1::Response> {
-	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
-
-	let res = services.rooms.pdu_metadata.paginate_relations_with_filter(
-		sender_user,
+	paginate_relations_with_filter(
+		&services,
+		body.sender_user(),
 		&body.room_id,
 		&body.event_id,
-		&None,
-		&Some(body.rel_type.clone()),
-		&body.from,
-		&body.to,
-		&body.limit,
+		None,
+		body.rel_type.clone().into(),
+		body.from.as_deref(),
+		body.to.as_deref(),
+		body.limit,
 		body.recurse,
 		body.dir,
-	)?;
-
-	Ok(get_relating_events_with_rel_type::v1::Response {
+	)
+	.await
+	.map(|res| get_relating_events_with_rel_type::v1::Response {
 		chunk: res.chunk,
 		next_batch: res.next_batch,
 		prev_batch: res.prev_batch,
@@ -63,18 +75,103 @@ pub(crate) async fn get_relating_events_route(
 	State(services): State<crate::State>, body: Ruma<get_relating_events::v1::Request>,
 ) -> Result<get_relating_events::v1::Response> {
-	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
-
-	services.rooms.pdu_metadata.paginate_relations_with_filter(
-		sender_user,
+	paginate_relations_with_filter(
+		&services,
+		body.sender_user(),
 		&body.room_id,
 		&body.event_id,
-		&None,
-		&None,
-		&body.from,
-		&body.to,
-		&body.limit,
+		None,
+		None,
+		body.from.as_deref(),
+		body.to.as_deref(),
+		body.limit,
 		body.recurse,
 		body.dir,
 	)
+	.await
+}
+
+#[allow(clippy::too_many_arguments)]
+async fn paginate_relations_with_filter(
+	services: &Services, sender_user: &UserId, room_id: &RoomId, target: &EventId,
+	filter_event_type: Option<TimelineEventType>, filter_rel_type: Option<RelationType>, from: Option<&str>,
+	to: Option<&str>, limit: Option<UInt>, recurse: bool, dir: Direction,
+) -> Result<get_relating_events::v1::Response> {
+	let start: PduCount = from
+		.map(str::parse)
+		.transpose()?
+ .unwrap_or_else(|| match dir { + Direction::Forward => PduCount::min(), + Direction::Backward => PduCount::max(), + }); + + let to: Option<PduCount> = to.map(str::parse).flat_ok(); + + // Use limit or else 30, with maximum 100 + let limit: usize = limit + .map(TryInto::try_into) + .flat_ok() + .unwrap_or(30) + .min(100); + + // Spec (v1.10) recommends depth of at least 3 + let depth: u8 = if recurse { + 3 + } else { + 1 + }; + + let events: Vec<PdusIterItem> = services + .rooms + .pdu_metadata + .get_relations(sender_user, room_id, target, start, limit, depth, dir) + .await + .into_iter() + .filter(|(_, pdu)| { + filter_event_type + .as_ref() + .is_none_or(|kind| *kind == pdu.kind) + }) + .filter(|(_, pdu)| { + filter_rel_type + .as_ref() + .is_none_or(|rel_type| pdu.relation_type_equal(rel_type)) + }) + .stream() + .filter_map(|item| visibility_filter(services, sender_user, item)) + .ready_take_while(|(count, _)| Some(*count) != to) + .take(limit) + .collect() + .boxed() + .await; + + let next_batch = match dir { + Direction::Forward => events.last(), + Direction::Backward => events.first(), + } + .map(at!(0)) + .as_ref() + .map(ToString::to_string); + + Ok(get_relating_events::v1::Response { + next_batch, + prev_batch: from.map(Into::into), + recursion_depth: recurse.then_some(depth.into()), + chunk: events + .into_iter() + .map(at!(1)) + .map(|pdu| pdu.to_message_like_event()) + .collect(), + }) +} + +async fn visibility_filter(services: &Services, sender_user: &UserId, item: PdusIterItem) -> Option<PdusIterItem> { + let (_, pdu) = &item; + + services + .rooms + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .await + .then_some(item) } diff --git a/src/api/client/report.rs b/src/api/client/report.rs index 588bd3686bc7959c655004ff70fb84ef776b1e28..a013370450579e887f04ade6e2cdc9273770ebd7 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -1,87 +1,130 @@ use std::time::Duration; use axum::extract::State; +use axum_client_ip::InsecureClientIp; +use conduit::{info, utils::ReadyExt, Err}; use rand::Rng; use ruma::{ - api::client::{error::ErrorKind, room::report_content}, + api::client::{ + error::ErrorKind, + room::{report_content, report_room}, + }, events::room::message, int, EventId, RoomId, UserId, }; use tokio::time::sleep; -use tracing::info; use crate::{ debug_info, service::{pdu::PduEvent, Services}, - utils::HtmlEscape, Error, Result, Ruma, }; +/// # `POST /_matrix/client/v3/rooms/{roomId}/report` +/// +/// Reports an abusive room to homeserver admins +#[tracing::instrument(skip_all, fields(%client), name = "report_room")] +pub(crate) async fn report_room_route( + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<report_room::v3::Request>, +) -> Result<report_room::v3::Response> { + // user authentication + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + info!( + "Received room report by user {sender_user} for room {} with reason: \"{}\"", + body.room_id, + body.reason.as_deref().unwrap_or("") + ); + + if body.reason.as_ref().is_some_and(|s| s.len() > 750) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Reason too long, should be 750 characters or fewer", + )); + }; + + delay_response().await; + + if !services + .rooms + .state_cache + .server_in_room(&services.globals.config.server_name, &body.room_id) + .await + { + return Err!(Request(NotFound( + "Room does not exist to us, no local users have joined at all" + ))); + } + + // send 
admin room message that we received the report with an @room ping for + // urgency + services + .admin + .send_message(message::RoomMessageEventContent::text_markdown(format!( + "@room Room report received from {} -\n\nRoom ID: {}\n\nReport Reason: {}", + sender_user.to_owned(), + body.room_id, + body.reason.as_deref().unwrap_or("") + ))) + .await + .ok(); + + Ok(report_room::v3::Response {}) +} + /// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}` /// /// Reports an inappropriate event to homeserver admins +#[tracing::instrument(skip_all, fields(%client), name = "report_event")] pub(crate) async fn report_event_route( - State(services): State<crate::State>, body: Ruma<report_content::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<report_content::v3::Request>, ) -> Result<report_content::v3::Response> { // user authentication let sender_user = body.sender_user.as_ref().expect("user is authenticated"); info!( - "Received /report request by user {sender_user} for room {} and event ID {}", - body.room_id, body.event_id + "Received event report by user {sender_user} for room {} and event ID {}, with reason: \"{}\"", + body.room_id, + body.event_id, + body.reason.as_deref().unwrap_or("") ); delay_response().await; // check if we know about the reported event ID or if it's invalid - let Some(pdu) = services.rooms.timeline.get_pdu(&body.event_id)? else { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Event ID is not known to us or Event ID is invalid", - )); + let Ok(pdu) = services.rooms.timeline.get_pdu(&body.event_id).await else { + return Err!(Request(NotFound("Event ID is not known to us or Event ID is invalid"))); }; - is_report_valid( + is_event_report_valid( &services, &pdu.event_id, &body.room_id, sender_user, - &body.reason, + body.reason.as_ref(), body.score, &pdu, - )?; + ) + .await?; // send admin room message that we received the report with an @room ping for // urgency services .admin - .send_message(message::RoomMessageEventContent::text_html( - format!( - "@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \ - Reason: {}", - sender_user.to_owned(), - pdu.event_id, - pdu.room_id, - pdu.sender.clone(), - body.score.unwrap_or_else(|| ruma::Int::from(0)), - body.reason.as_deref().unwrap_or("") - ), - format!( - "<details><summary>@room Report received from: <a href=\"https://matrix.to/#/{0}\">{0}\ - </a></summary><ul><li>Event Info<ul><li>Event ID: <code>{1}</code>\ - <a href=\"https://matrix.to/#/{2}/{1}\">🔗</a></li><li>Room ID: <code>{2}</code>\ - </li><li>Sent By: <a href=\"https://matrix.to/#/{3}\">{3}</a></li></ul></li><li>\ - Report Info<ul><li>Report Score: {4}</li><li>Report Reason: {5}</li></ul></li>\ - </ul></details>", - sender_user.to_owned(), - pdu.event_id.clone(), - pdu.room_id.clone(), - pdu.sender.clone(), - body.score.unwrap_or_else(|| ruma::Int::from(0)), - HtmlEscape(body.reason.as_deref().unwrap_or("")) - ), - )) - .await; + .send_message(message::RoomMessageEventContent::text_markdown(format!( + "@room Event report received from {} -\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: \ + {}\nReport Reason: {}", + sender_user.to_owned(), + pdu.event_id, + pdu.room_id, + pdu.sender, + body.score.unwrap_or_else(|| ruma::Int::from(0)), + body.reason.as_deref().unwrap_or("") + ))) + .await + .ok(); Ok(report_content::v3::Response {}) } @@ -92,8 +135,8 @@ pub(crate) async fn report_event_route( /// check if score is in valid 
range /// check if report reasoning is less than or equal to 750 characters /// check if reporting user is in the reporting room -fn is_report_valid( - services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option<String>, +async fn is_event_report_valid( + services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: Option<&String>, score: Option<ruma::Int>, pdu: &std::sync::Arc<PduEvent>, ) -> Result<()> { debug_info!("Checking if report from user {sender_user} for event {event_id} in room {room_id} is valid"); @@ -123,8 +166,8 @@ fn is_report_valid( .rooms .state_cache .room_members(room_id) - .filter_map(Result::ok) - .any(|user_id| user_id == *sender_user) + .ready_any(|user_id| user_id == sender_user) + .await { return Err(Error::BadRequest( ErrorKind::NotFound, @@ -139,7 +182,7 @@ fn is_report_valid( /// random delay sending a response per spec suggestion regarding /// enumerating for potential events existing in our server. async fn delay_response() { - let time_to_wait = rand::thread_rng().gen_range(3..10); + let time_to_wait = rand::thread_rng().gen_range(2..5); debug_info!("Got successful /report request, waiting {time_to_wait} seconds before sending successful response."); sleep(Duration::from_secs(time_to_wait)).await; } diff --git a/src/api/client/room/aliases.rs b/src/api/client/room/aliases.rs new file mode 100644 index 0000000000000000000000000000000000000000..e530b26027aa1b55089803ba519428da9aa52b7e --- /dev/null +++ b/src/api/client/room/aliases.rs @@ -0,0 +1,40 @@ +use axum::extract::State; +use conduit::{Error, Result}; +use futures::StreamExt; +use ruma::api::client::{error::ErrorKind, room::aliases}; + +use crate::Ruma; + +/// # `GET /_matrix/client/r0/rooms/{roomId}/aliases` +/// +/// Lists all aliases of the room. 
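+///
+/// A successful response is just the alias list, e.g.
+/// `{"aliases": ["#somewhere:example.org"]}` (illustrative alias).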
+/// +/// - Only users joined to the room are allowed to call this, or if +/// `history_visibility` is world readable in the room +pub(crate) async fn get_room_aliases_route( + State(services): State<crate::State>, body: Ruma<aliases::v3::Request>, +) -> Result<aliases::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + if !services + .rooms + .state_accessor + .user_can_see_state_events(sender_user, &body.room_id) + .await + { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "You don't have permission to view this room.", + )); + } + + Ok(aliases::v3::Response { + aliases: services + .rooms + .alias + .local_aliases_for_room(&body.room_id) + .map(ToOwned::to_owned) + .collect() + .await, + }) +} diff --git a/src/api/client/room.rs b/src/api/client/room/create.rs similarity index 51% rename from src/api/client/room.rs rename to src/api/client/room/create.rs index 0112e76dcb475603b96067f87d7d75dba2b76d20..2ccb1c87af6d7da39dafd353c8b29ab4e3fbc7b3 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room/create.rs @@ -1,11 +1,12 @@ -use std::{cmp::max, collections::BTreeMap}; +use std::collections::BTreeMap; use axum::extract::State; -use conduit::{debug_info, debug_warn, err, Err}; +use conduit::{debug_info, debug_warn, error, info, pdu::PduBuilder, warn, Err, Error, Result}; +use futures::FutureExt; use ruma::{ api::client::{ error::ErrorKind, - room::{self, aliases, create_room, get_room_event, upgrade_room}, + room::{self, create_room}, }, events::{ room::{ @@ -17,36 +18,18 @@ member::{MembershipState, RoomMemberEventContent}, name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent, - tombstone::RoomTombstoneEventContent, topic::RoomTopicEventContent, }, - StateEventType, TimelineEventType, + TimelineEventType, }, int, serde::{JsonObject, Raw}, CanonicalJsonObject, Int, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, RoomVersionId, }; use serde_json::{json, value::to_raw_value}; -use tracing::{error, info, warn}; +use service::{appservice::RegistrationInfo, Services}; -use super::invite_helper; -use crate::{ - service::{appservice::RegistrationInfo, pdu::PduBuilder, Services}, - Error, Result, Ruma, -}; - -/// Recommended transferable state events list from the spec -const TRANSFERABLE_STATE_EVENTS: &[StateEventType; 9] = &[ - StateEventType::RoomServerAcl, - StateEventType::RoomEncryption, - StateEventType::RoomName, - StateEventType::RoomAvatar, - StateEventType::RoomTopic, - StateEventType::RoomGuestAccess, - StateEventType::RoomHistoryVisibility, - StateEventType::RoomJoinRules, - StateEventType::RoomPowerLevels, -]; +use crate::{client::invite_helper, Ruma}; /// # `POST /_matrix/client/v3/createRoom` /// @@ -74,7 +57,7 @@ pub(crate) async fn create_room_route( if !services.globals.allow_room_creation() && body.appservice_info.is_none() - && !services.users.is_admin(sender_user)? 
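+		// note: server admins bypass the allow_room_creation() restriction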
+ && !services.users.is_admin(sender_user).await { return Err(Error::BadRequest(ErrorKind::forbidden(), "Room creation has been disabled.")); } @@ -86,7 +69,7 @@ pub(crate) async fn create_room_route( }; // check if room ID doesn't already exist instead of erroring on auth check - if services.rooms.short.get_shortroomid(&room_id)?.is_some() { + if services.rooms.short.get_shortroomid(&room_id).await.is_ok() { return Err(Error::BadRequest( ErrorKind::RoomInUse, "Room with that custom room ID already exists", @@ -95,7 +78,7 @@ pub(crate) async fn create_room_route( if body.visibility == room::Visibility::Public && services.globals.config.lockdown_public_room_directory - && !services.users.is_admin(sender_user)? + && !services.users.is_admin(sender_user).await && body.appservice_info.is_none() { info!( @@ -118,11 +101,15 @@ pub(crate) async fn create_room_route( return Err!(Request(Forbidden("Publishing rooms to the room directory is not allowed"))); } - let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; + let _short_id = services + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await; let state_lock = services.rooms.state.mutex.lock(&room_id).await; - let alias: Option<OwnedRoomAliasId> = if let Some(alias) = &body.room_alias_name { - Some(room_alias_check(&services, alias, &body.appservice_info).await?) + let alias: Option<OwnedRoomAliasId> = if let Some(alias) = body.room_alias_name.as_ref() { + Some(room_alias_check(&services, alias, body.appservice_info.as_ref()).await?) } else { None }; @@ -145,8 +132,7 @@ pub(crate) async fn create_room_route( None => services.globals.default_room_version(), }; - #[allow(clippy::single_match_else)] - let content = match &body.creation_content { + let create_content = match &body.creation_content { Some(content) => { use RoomVersionId::*; @@ -208,16 +194,15 @@ pub(crate) async fn create_room_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, + content: to_raw_value(&create_content).expect("create event content serialization"), state_key: Some(String::new()), - redacts: None, - timestamp: None, + ..Default::default() }, sender_user, &room_id, &state_lock, ) + .boxed() .await?; // 2. Let the room creator join @@ -225,28 +210,21 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, + PduBuilder::state( + sender_user.to_string(), + &RoomMemberEventContent { + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + blurhash: services.users.blurhash(sender_user).await.ok(), is_direct: Some(body.is_direct), - third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - timestamp: None, - }, + ..RoomMemberEventContent::new(MembershipState::Join) + }, + ), sender_user, &room_id, &state_lock, ) + .boxed() .await?; // 3. 
Power levels @@ -260,13 +238,21 @@ pub(crate) async fn create_room_route( let mut users = BTreeMap::from_iter([(sender_user.clone(), int!(100))]); if preset == RoomPreset::TrustedPrivateChat { - for invite_ in &body.invite { - users.insert(invite_.clone(), int!(100)); + for invite in &body.invite { + if services.users.user_is_ignored(sender_user, invite).await { + return Err!(Request(Forbidden("You cannot invite users you have ignored to rooms."))); + } else if services.users.user_is_ignored(invite, sender_user).await { + // silently drop the invite to the recipient if they've been ignored by the + // sender, pretend it worked + continue; + } + + users.insert(invite.clone(), int!(100)); } } let power_levels_content = - default_power_levels_content(&body.power_level_content_override, &body.visibility, users)?; + default_power_levels_content(body.power_level_content_override.as_ref(), &body.visibility, users)?; services .rooms @@ -274,16 +260,15 @@ pub(crate) async fn create_room_route( .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&power_levels_content).expect("to_raw_value always works on serde_json::Value"), - unsigned: None, + content: to_raw_value(&power_levels_content).expect("serialized power_levels event content"), state_key: Some(String::new()), - redacts: None, - timestamp: None, + ..Default::default() }, sender_user, &room_id, &state_lock, ) + .boxed() .await?; // 4. Canonical room alias @@ -292,22 +277,18 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCanonicalAlias, - content: to_raw_value(&RoomCanonicalAliasEventContent { + PduBuilder::state( + String::new(), + &RoomCanonicalAliasEventContent { alias: Some(room_alias_id.to_owned()), alt_aliases: vec![], - }) - .expect("We checked that alias earlier, it must be fine"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), sender_user, &room_id, &state_lock, ) + .boxed() .await?; } @@ -318,23 +299,19 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomJoinRules, - content: to_raw_value(&RoomJoinRulesEventContent::new(match preset { + PduBuilder::state( + String::new(), + &RoomJoinRulesEventContent::new(match preset { RoomPreset::PublicChat => JoinRule::Public, // according to spec "invite" is the default _ => JoinRule::Invite, - })) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }), + ), sender_user, &room_id, &state_lock, ) + .boxed() .await?; // 5.2 History Visibility @@ -342,19 +319,15 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomHistoryVisibility, - content: to_raw_value(&RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + String::new(), + &RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared), + ), sender_user, &room_id, &state_lock, ) + .boxed() .await?; // 5.3 Guest Access @@ -362,22 +335,18 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomGuestAccess, - content: 
to_raw_value(&RoomGuestAccessEventContent::new(match preset { + PduBuilder::state( + String::new(), + &RoomGuestAccessEventContent::new(match preset { RoomPreset::PublicChat => GuestAccess::Forbidden, _ => GuestAccess::CanJoin, - })) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }), + ), sender_user, &room_id, &state_lock, ) + .boxed() .await?; // 6. Events listed in initial_state @@ -410,6 +379,7 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) + .boxed() .await?; } @@ -419,19 +389,12 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomName, - content: to_raw_value(&RoomNameEventContent::new(name.clone())) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &RoomNameEventContent::new(name.clone())), sender_user, &room_id, &state_lock, ) + .boxed() .await?; } @@ -440,28 +403,35 @@ pub(crate) async fn create_room_route( .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomTopic, - content: to_raw_value(&RoomTopicEventContent { + PduBuilder::state( + String::new(), + &RoomTopicEventContent { topic: topic.clone(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), sender_user, &room_id, &state_lock, ) + .boxed() .await?; } // 8. Events implied by invite (and TODO: invite_3pid) drop(state_lock); for user_id in &body.invite { - if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct).await { + if services.users.user_is_ignored(sender_user, user_id).await { + return Err!(Request(Forbidden("You cannot invite users you have ignored to rooms."))); + } else if services.users.user_is_ignored(user_id, sender_user).await { + // silently drop the invite to the recipient if they've been ignored by the + // sender, pretend it worked + continue; + } + + if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct) + .boxed() + .await + { warn!(%e, "Failed to send invite"); } } @@ -475,7 +445,7 @@ pub(crate) async fn create_room_route( } if body.visibility == room::Visibility::Public { - services.rooms.directory.set_public(&room_id)?; + services.rooms.directory.set_public(&room_id); if services.globals.config.admin_room_notices { services @@ -491,353 +461,9 @@ pub(crate) async fn create_room_route( Ok(create_room::v3::Response::new(room_id)) } -/// # `GET /_matrix/client/r0/rooms/{roomId}/event/{eventId}` -/// -/// Gets a single event. -/// -/// - You have to currently be joined to the room (TODO: Respect history -/// visibility) -pub(crate) async fn get_room_event_route( - State(services): State<crate::State>, body: Ruma<get_room_event::v3::Request>, -) -> Result<get_room_event::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let event = services - .rooms - .timeline - .get_pdu(&body.event_id)? - .ok_or_else(|| err!(Request(NotFound("Event {} not found.", &body.event_id))))?; - - if !services - .rooms - .state_accessor - .user_can_see_event(sender_user, &event.room_id, &body.event_id)? 
- { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this event.", - )); - } - - let mut event = (*event).clone(); - event.add_age()?; - - Ok(get_room_event::v3::Response { - event: event.to_room_event(), - }) -} - -/// # `GET /_matrix/client/r0/rooms/{roomId}/aliases` -/// -/// Lists all aliases of the room. -/// -/// - Only users joined to the room are allowed to call this, or if -/// `history_visibility` is world readable in the room -pub(crate) async fn get_room_aliases_route( - State(services): State<crate::State>, body: Ruma<aliases::v3::Request>, -) -> Result<aliases::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - if !services - .rooms - .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); - } - - Ok(aliases::v3::Response { - aliases: services - .rooms - .alias - .local_aliases_for_room(&body.room_id) - .filter_map(Result::ok) - .collect(), - }) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/upgrade` -/// -/// Upgrades the room. -/// -/// - Creates a replacement room -/// - Sends a tombstone event into the current room -/// - Sender user joins the room -/// - Transfers some state events -/// - Moves local aliases -/// - Modifies old room power levels to prevent users from speaking -pub(crate) async fn upgrade_room_route( - State(services): State<crate::State>, body: Ruma<upgrade_room::v3::Request>, -) -> Result<upgrade_room::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - if !services - .globals - .supported_room_versions() - .contains(&body.new_version) - { - return Err(Error::BadRequest( - ErrorKind::UnsupportedRoomVersion, - "This server does not support that room version.", - )); - } - - // Create a replacement room - let replacement_room = RoomId::new(services.globals.server_name()); - - let _short_id = services - .rooms - .short - .get_or_create_shortroomid(&replacement_room)?; - - let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - - // Send a m.room.tombstone event to the old room to indicate that it is not - // intended to be used any further Fail if the sender does not have the required - // permissions - let tombstone_event_id = services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomTombstone, - content: to_raw_value(&RoomTombstoneEventContent { - body: "This room has been replaced".to_owned(), - replacement_room: replacement_room.clone(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, - sender_user, - &body.room_id, - &state_lock, - ) - .await?; - - // Change lock to replacement room - drop(state_lock); - let state_lock = services.rooms.state.mutex.lock(&replacement_room).await; - - // Get the old room creation event - let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>( - services - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? 
- .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; - - // Use the m.room.tombstone event as the predecessor - let predecessor = Some(ruma::events::room::create::PreviousRoom::new( - body.room_id.clone(), - (*tombstone_event_id).to_owned(), - )); - - // Send a m.room.create event containing a predecessor field and the applicable - // room_version - { - use RoomVersionId::*; - match body.new_version { - V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { - create_event_content.insert( - "creator".into(), - json!(&sender_user).try_into().map_err(|e| { - info!("Error forming creation event: {e}"); - Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") - })?, - ); - }, - _ => { - // "creator" key no longer exists in V11+ rooms - create_event_content.remove("creator"); - }, - } - } - - create_event_content.insert( - "room_version".into(), - json!(&body.new_version) - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, - ); - create_event_content.insert( - "predecessor".into(), - json!(predecessor) - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, - ); - - // Validate creation event content - if serde_json::from_str::<CanonicalJsonObject>( - to_raw_value(&create_event_content) - .expect("Error forming creation event") - .get(), - ) - .is_err() - { - return Err(Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")); - } - - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&create_event_content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, - sender_user, - &replacement_room, - &state_lock, - ) - .await?; - - // Join the new room - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services.users.displayname(sender_user)?, - avatar_url: services.users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services.users.blurhash(sender_user)?, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - timestamp: None, - }, - sender_user, - &replacement_room, - &state_lock, - ) - .await?; - - // Replicate transferable state events to the new room - for event_type in TRANSFERABLE_STATE_EVENTS { - let event_content = match services - .rooms - .state_accessor - .room_state_get(&body.room_id, event_type, "")? - { - Some(v) => v.content.clone(), - None => continue, // Skipping missing events. 
- }; - - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: event_type.to_string().into(), - content: event_content, - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, - sender_user, - &replacement_room, - &state_lock, - ) - .await?; - } - - // Moves any local aliases to the new room - for alias in services - .rooms - .alias - .local_aliases_for_room(&body.room_id) - .filter_map(Result::ok) - { - services - .rooms - .alias - .remove_alias(&alias, sender_user) - .await?; - services - .rooms - .alias - .set_alias(&alias, &replacement_room, sender_user)?; - } - - // Get the old room power levels - let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( - services - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; - - // Setting events_default and invite to the greater of 50 and users_default + 1 - let new_level = max( - int!(50), - power_levels_event_content - .users_default - .checked_add(int!(1)) - .ok_or_else(|| { - Error::BadRequest(ErrorKind::BadJson, "users_default power levels event content is not valid") - })?, - ); - power_levels_event_content.events_default = new_level; - power_levels_event_content.invite = new_level; - - // Modify the power levels in the old room to prevent sending of events and - // inviting new users - services - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&power_levels_event_content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, - sender_user, - &body.room_id, - &state_lock, - ) - .await?; - - drop(state_lock); - - // Return the replacement room id - Ok(upgrade_room::v3::Response { - replacement_room, - }) -} - /// creates the power_levels_content for the PDU builder fn default_power_levels_content( - power_level_content_override: &Option<Raw<RoomPowerLevelsEventContent>>, visibility: &room::Visibility, + power_level_content_override: Option<&Raw<RoomPowerLevelsEventContent>>, visibility: &room::Visibility, users: BTreeMap<OwnedUserId, Int>, ) -> Result<serde_json::Value> { let mut power_levels_content = serde_json::to_value(RoomPowerLevelsEventContent { @@ -887,7 +513,7 @@ fn default_power_levels_content( /// if a room is being created with a room alias, run our checks async fn room_alias_check( - services: &Services, room_alias_name: &str, appservice_info: &Option<RegistrationInfo>, + services: &Services, room_alias_name: &str, appservice_info: Option<&RegistrationInfo>, ) -> Result<OwnedRoomAliasId> { // Basic checks on the room alias validity if room_alias_name.contains(':') { @@ -921,13 +547,14 @@ async fn room_alias_check( if services .rooms .alias - .resolve_local_alias(&full_room_alias)? 
- .is_some() + .resolve_local_alias(&full_room_alias) + .await + .is_ok() { return Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")); } - if let Some(ref info) = appservice_info { + if let Some(info) = appservice_info { if !info.aliases.is_match(full_room_alias.as_str()) { return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace.")); } diff --git a/src/api/client/room/event.rs b/src/api/client/room/event.rs new file mode 100644 index 0000000000000000000000000000000000000000..0f44f25d26dada02d36bf4d645b7f1aad4edf99a --- /dev/null +++ b/src/api/client/room/event.rs @@ -0,0 +1,38 @@ +use axum::extract::State; +use conduit::{err, Result}; +use futures::TryFutureExt; +use ruma::api::client::room::get_room_event; + +use crate::Ruma; + +/// # `GET /_matrix/client/r0/rooms/{roomId}/event/{eventId}` +/// +/// Gets a single event. +/// +/// - You have to currently be joined to the room (TODO: Respect history +/// visibility) +pub(crate) async fn get_room_event_route( + State(services): State<crate::State>, ref body: Ruma<get_room_event::v3::Request>, +) -> Result<get_room_event::v3::Response> { + Ok(get_room_event::v3::Response { + event: services + .rooms + .timeline + .get_pdu_owned(&body.event_id) + .map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id)))) + .and_then(|event| async move { + services + .rooms + .state_accessor + .user_can_see_event(body.sender_user(), &event.room_id, &body.event_id) + .await + .then_some(event) + .ok_or_else(|| err!(Request(Forbidden("You don't have permission to view this event.")))) + }) + .map_ok(|mut event| { + event.add_age().ok(); + event.to_room_event() + }) + .await?, + }) +} diff --git a/src/api/client/room/mod.rs b/src/api/client/room/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..fa2d168f0c6c17963411488dcacc20efbfe34b3c --- /dev/null +++ b/src/api/client/room/mod.rs @@ -0,0 +1,9 @@ +mod aliases; +mod create; +mod event; +mod upgrade; + +pub(crate) use self::{ + aliases::get_room_aliases_route, create::create_room_route, event::get_room_event_route, + upgrade::upgrade_room_route, +}; diff --git a/src/api/client/room/upgrade.rs b/src/api/client/room/upgrade.rs new file mode 100644 index 0000000000000000000000000000000000000000..ad5c356e81ce39368ed957844904391177c6bcc3 --- /dev/null +++ b/src/api/client/room/upgrade.rs @@ -0,0 +1,294 @@ +use std::cmp::max; + +use axum::extract::State; +use conduit::{err, info, pdu::PduBuilder, Error, Result}; +use futures::StreamExt; +use ruma::{ + api::client::{error::ErrorKind, room::upgrade_room}, + events::{ + room::{ + member::{MembershipState, RoomMemberEventContent}, + power_levels::RoomPowerLevelsEventContent, + tombstone::RoomTombstoneEventContent, + }, + StateEventType, TimelineEventType, + }, + int, CanonicalJsonObject, RoomId, RoomVersionId, +}; +use serde_json::{json, value::to_raw_value}; + +use crate::Ruma; + +/// Recommended transferable state events list from the spec +const TRANSFERABLE_STATE_EVENTS: &[StateEventType; 9] = &[ + StateEventType::RoomServerAcl, + StateEventType::RoomEncryption, + StateEventType::RoomName, + StateEventType::RoomAvatar, + StateEventType::RoomTopic, + StateEventType::RoomGuestAccess, + StateEventType::RoomHistoryVisibility, + StateEventType::RoomJoinRules, + StateEventType::RoomPowerLevels, +]; + +/// # `POST /_matrix/client/r0/rooms/{roomId}/upgrade` +/// +/// Upgrades the room. 
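+///
+/// The replacement room's `m.room.create` gains a `predecessor` field
+/// referencing the old room and its tombstone event (handled below).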
+/// +/// - Creates a replacement room +/// - Sends a tombstone event into the current room +/// - Sender user joins the room +/// - Transfers some state events +/// - Moves local aliases +/// - Modifies old room power levels to prevent users from speaking +pub(crate) async fn upgrade_room_route( + State(services): State<crate::State>, body: Ruma<upgrade_room::v3::Request>, +) -> Result<upgrade_room::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + if !services + .globals + .supported_room_versions() + .contains(&body.new_version) + { + return Err(Error::BadRequest( + ErrorKind::UnsupportedRoomVersion, + "This server does not support that room version.", + )); + } + + // Create a replacement room + let replacement_room = RoomId::new(services.globals.server_name()); + + let _short_id = services + .rooms + .short + .get_or_create_shortroomid(&replacement_room) + .await; + + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + + // Send a m.room.tombstone event to the old room to indicate that it is not + // intended to be used any further Fail if the sender does not have the required + // permissions + let tombstone_event_id = services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder::state( + String::new(), + &RoomTombstoneEventContent { + body: "This room has been replaced".to_owned(), + replacement_room: replacement_room.clone(), + }, + ), + sender_user, + &body.room_id, + &state_lock, + ) + .await?; + + // Change lock to replacement room + drop(state_lock); + let state_lock = services.rooms.state.mutex.lock(&replacement_room).await; + + // Get the old room creation event + let mut create_event_content: CanonicalJsonObject = services + .rooms + .state_accessor + .room_state_get_content(&body.room_id, &StateEventType::RoomCreate, "") + .await + .map_err(|_| err!(Database("Found room without m.room.create event.")))?; + + // Use the m.room.tombstone event as the predecessor + let predecessor = Some(ruma::events::room::create::PreviousRoom::new( + body.room_id.clone(), + (*tombstone_event_id).to_owned(), + )); + + // Send a m.room.create event containing a predecessor field and the applicable + // room_version + { + use RoomVersionId::*; + match body.new_version { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { + create_event_content.insert( + "creator".into(), + json!(&sender_user).try_into().map_err(|e| { + info!("Error forming creation event: {e}"); + Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") + })?, + ); + }, + _ => { + // "creator" key no longer exists in V11+ rooms + create_event_content.remove("creator"); + }, + } + } + + create_event_content.insert( + "room_version".into(), + json!(&body.new_version) + .try_into() + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, + ); + create_event_content.insert( + "predecessor".into(), + json!(predecessor) + .try_into() + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, + ); + + // Validate creation event content + if serde_json::from_str::<CanonicalJsonObject>( + to_raw_value(&create_event_content) + .expect("Error forming creation event") + .get(), + ) + .is_err() + { + return Err(Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")); + } + + services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomCreate, + content: to_raw_value(&create_event_content).expect("event is valid, we just created 
it"), + unsigned: None, + state_key: Some(String::new()), + redacts: None, + timestamp: None, + }, + sender_user, + &replacement_room, + &state_lock, + ) + .await?; + + // Join the new room + services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services.users.displayname(sender_user).await.ok(), + avatar_url: services.users.avatar_url(sender_user).await.ok(), + is_direct: None, + third_party_invite: None, + blurhash: services.users.blurhash(sender_user).await.ok(), + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + timestamp: None, + }, + sender_user, + &replacement_room, + &state_lock, + ) + .await?; + + // Replicate transferable state events to the new room + for event_type in TRANSFERABLE_STATE_EVENTS { + let event_content = match services + .rooms + .state_accessor + .room_state_get(&body.room_id, event_type, "") + .await + { + Ok(v) => v.content.clone(), + Err(_) => continue, // Skipping missing events. + }; + + services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: event_type.to_string().into(), + content: event_content, + state_key: Some(String::new()), + ..Default::default() + }, + sender_user, + &replacement_room, + &state_lock, + ) + .await?; + } + + // Moves any local aliases to the new room + let mut local_aliases = services + .rooms + .alias + .local_aliases_for_room(&body.room_id) + .boxed(); + + while let Some(alias) = local_aliases.next().await { + services + .rooms + .alias + .remove_alias(alias, sender_user) + .await?; + + services + .rooms + .alias + .set_alias(alias, &replacement_room, sender_user)?; + } + + // Get the old room power levels + let power_levels_event_content: RoomPowerLevelsEventContent = services + .rooms + .state_accessor + .room_state_get_content(&body.room_id, &StateEventType::RoomPowerLevels, "") + .await + .map_err(|_| err!(Database("Found room without m.room.power_levels event.")))?; + + // Setting events_default and invite to the greater of 50 and users_default + 1 + let new_level = max( + int!(50), + power_levels_event_content + .users_default + .checked_add(int!(1)) + .ok_or_else(|| err!(Request(BadJson("users_default power levels event content is not valid"))))?, + ); + + // Modify the power levels in the old room to prevent sending of events and + // inviting new users + services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder::state( + String::new(), + &RoomPowerLevelsEventContent { + events_default: new_level, + invite: new_level, + ..power_levels_event_content + }, + ), + sender_user, + &body.room_id, + &state_lock, + ) + .await?; + + drop(state_lock); + + // Return the replacement room id + Ok(upgrade_room::v3::Response { + replacement_room, + }) +} diff --git a/src/api/client/search.rs b/src/api/client/search.rs index b143bd2c7f1eb9cb7d9cdd1cde4a752e9b97f595..1e5384fe26c1c7c7b4b262080d31875a9deb8a55 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -1,21 +1,33 @@ use std::collections::BTreeMap; use axum::extract::State; +use conduit::{ + at, is_true, + result::FlatOk, + utils::{stream::ReadyExt, IterStream}, + Err, PduEvent, Result, +}; +use futures::{future::OptionFuture, FutureExt, StreamExt, TryFutureExt}; use ruma::{ - api::client::{ - error::ErrorKind, - 
search::search_events::{ - self, - v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, - }, + api::client::search::search_events::{ + self, + v3::{Criteria, EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, }, events::AnyStateEvent, serde::Raw, - uint, OwnedRoomId, + OwnedRoomId, RoomId, UInt, UserId, }; -use tracing::debug; +use search_events::v3::{Request, Response}; +use service::{rooms::search::RoomQuery, Services}; + +use crate::Ruma; -use crate::{Error, Result, Ruma}; +type RoomStates = BTreeMap<OwnedRoomId, RoomState>; +type RoomState = Vec<Raw<AnyStateEvent>>; + +const LIMIT_DEFAULT: usize = 10; +const LIMIT_MAX: usize = 100; +const BATCH_MAX: usize = 20; /// # `POST /_matrix/client/r0/search` /// @@ -23,160 +35,177 @@ /// /// - Only works if the user is currently joined to the room (TODO: Respect /// history visibility) -pub(crate) async fn search_events_route( - State(services): State<crate::State>, body: Ruma<search_events::v3::Request>, -) -> Result<search_events::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let search_criteria = body.search_categories.room_events.as_ref().unwrap(); - let filter = &search_criteria.filter; - let include_state = &search_criteria.include_state; +pub(crate) async fn search_events_route(State(services): State<crate::State>, body: Ruma<Request>) -> Result<Response> { + let sender_user = body.sender_user(); + let next_batch = body.next_batch.as_deref(); + let room_events_result: OptionFuture<_> = body + .search_categories + .room_events + .as_ref() + .map(|criteria| category_room_events(&services, sender_user, next_batch, criteria)) + .into(); + + Ok(Response { + search_categories: ResultCategories { + room_events: room_events_result + .await + .unwrap_or_else(|| Ok(ResultRoomEvents::default()))?, + }, + }) +} - let room_ids = filter.rooms.clone().unwrap_or_else(|| { - services - .rooms - .state_cache - .rooms_joined(sender_user) - .filter_map(Result::ok) - .collect() - }); +#[allow(clippy::map_unwrap_or)] +async fn category_room_events( + services: &Services, sender_user: &UserId, next_batch: Option<&str>, criteria: &Criteria, +) -> Result<ResultRoomEvents> { + let filter = &criteria.filter; - // Use limit or else 10, with maximum 100 let limit: usize = filter .limit - .unwrap_or_else(|| uint!(10)) - .try_into() - .unwrap_or(10) - .min(100); - - let mut room_states: BTreeMap<OwnedRoomId, Vec<Raw<AnyStateEvent>>> = BTreeMap::new(); - - if include_state.is_some_and(|include_state| include_state) { - for room_id in &room_ids { - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); - } - - // check if sender_user can see state events - if services + .map(TryInto::try_into) + .flat_ok() + .unwrap_or(LIMIT_DEFAULT) + .min(LIMIT_MAX); + + let next_batch: usize = next_batch + .map(str::parse) + .transpose()? + .unwrap_or(0) + .min(limit.saturating_mul(BATCH_MAX)); + + let rooms = filter + .rooms + .clone() + .map(IntoIterator::into_iter) + .map(IterStream::stream) + .map(StreamExt::boxed) + .unwrap_or_else(|| { + services .rooms - .state_accessor - .user_can_see_state_events(sender_user, room_id)? - { - let room_state = services - .rooms - .state_accessor - .room_state_full(room_id) - .await? 
- .values() - .map(|pdu| pdu.to_state_event()) - .collect::<Vec<_>>(); - - debug!("Room state: {:?}", room_state); - - room_states.insert(room_id.clone(), room_state); - } else { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); - } - } - } - - let mut searches = Vec::new(); - - for room_id in &room_ids { - if !services.rooms.state_cache.is_joined(sender_user, room_id)? { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); - } - - if let Some(search) = services - .rooms - .search - .search_pdus(room_id, &search_criteria.search_term)? - { - searches.push(search.0.peekable()); - } - } + .state_cache + .rooms_joined(sender_user) + .map(ToOwned::to_owned) + .boxed() + }); + + let results: Vec<_> = rooms + .filter_map(|room_id| async move { + check_room_visible(services, sender_user, &room_id, criteria) + .await + .is_ok() + .then_some(room_id) + }) + .filter_map(|room_id| async move { + let query = RoomQuery { + room_id: &room_id, + user_id: Some(sender_user), + criteria, + skip: next_batch, + limit, + }; + + let (count, results) = services.rooms.search.search_pdus(&query).await.ok()?; + + results + .collect::<Vec<_>>() + .map(|results| (room_id.clone(), count, results)) + .map(Some) + .await + }) + .collect() + .await; - let skip: usize = match body.next_batch.as_ref().map(|s| s.parse()) { - Some(Ok(s)) => s, - Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")), - None => 0, // Default to the start - }; - - let mut results = Vec::new(); - let next_batch = skip.saturating_add(limit); - - for _ in 0..next_batch { - if let Some(s) = searches - .iter_mut() - .map(|s| (s.peek().cloned(), s)) - .max_by_key(|(peek, _)| peek.clone()) - .and_then(|(_, i)| i.next()) - { - results.push(s); - } - } + let total: UInt = results + .iter() + .fold(0, |a: usize, (_, count, _)| a.saturating_add(*count)) + .try_into()?; - let results: Vec<_> = results + let state: RoomStates = results .iter() - .skip(skip) - .filter_map(|result| { - services - .rooms - .timeline - .get_pdu_from_id(result) - .ok()? 
- .filter(|pdu| { - !pdu.is_redacted() - && services - .rooms - .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) - .unwrap_or(false) - }) - .map(|pdu| pdu.to_room_event()) + .stream() + .ready_filter(|_| criteria.include_state.is_some_and(is_true!())) + .filter_map(|(room_id, ..)| async move { + procure_room_state(services, room_id) + .map_ok(|state| (room_id.clone(), state)) + .await + .ok() }) - .map(|result| { - Ok::<_, Error>(SearchResult { - context: EventContextResult { - end: None, - events_after: Vec::new(), - events_before: Vec::new(), - profile_info: BTreeMap::new(), - start: None, - }, - rank: None, - result: Some(result), - }) + .collect() + .await; + + let results: Vec<SearchResult> = results + .into_iter() + .map(at!(2)) + .flatten() + .stream() + .map(|pdu| pdu.to_room_event()) + .map(|result| SearchResult { + rank: None, + result: Some(result), + context: EventContextResult { + profile_info: BTreeMap::new(), //TODO + events_after: Vec::new(), //TODO + events_before: Vec::new(), //TODO + start: None, //TODO + end: None, //TODO + }, }) - .filter_map(Result::ok) - .take(limit) + .collect() + .await; + + let highlights = criteria + .search_term + .split_terminator(|c: char| !c.is_alphanumeric()) + .map(str::to_lowercase) .collect(); - let more_unloaded_results = searches.iter_mut().any(|s| s.peek().is_some()); - let next_batch = more_unloaded_results.then(|| next_batch.to_string()); - - Ok(search_events::v3::Response::new(ResultCategories { - room_events: ResultRoomEvents { - count: Some(results.len().try_into().unwrap_or_else(|_| uint!(0))), - groups: BTreeMap::new(), // TODO - next_batch, - results, - state: room_states, - highlights: search_criteria - .search_term - .split_terminator(|c: char| !c.is_alphanumeric()) - .map(str::to_lowercase) - .collect(), - }, - })) + let next_batch = (results.len() >= limit) + .then_some(next_batch.saturating_add(results.len())) + .as_ref() + .map(ToString::to_string); + + Ok(ResultRoomEvents { + count: Some(total), + next_batch, + results, + state, + highlights, + groups: BTreeMap::new(), // TODO + }) +} + +async fn procure_room_state(services: &Services, room_id: &RoomId) -> Result<RoomState> { + let state_map = services + .rooms + .state_accessor + .room_state_full(room_id) + .await?; + + let state_events = state_map + .values() + .map(AsRef::as_ref) + .map(PduEvent::to_state_event) + .collect(); + + Ok(state_events) +} + +async fn check_room_visible(services: &Services, user_id: &UserId, room_id: &RoomId, search: &Criteria) -> Result { + let check_visible = search.filter.rooms.is_some(); + let check_state = check_visible && search.include_state.is_some_and(is_true!()); + + let is_joined = !check_visible || services.rooms.state_cache.is_joined(user_id, room_id).await; + + let state_visible = !check_state + || services + .rooms + .state_accessor + .user_can_see_state_events(user_id, room_id) + .await; + + if !is_joined || !state_visible { + return Err!(Request(Forbidden("You don't have permission to view {room_id:?}"))); + } + + Ok(()) } diff --git a/src/api/client/send.rs b/src/api/client/send.rs new file mode 100644 index 0000000000000000000000000000000000000000..ff011efabc75ada20d4fe363d59128142af50d67 --- /dev/null +++ b/src/api/client/send.rs @@ -0,0 +1,92 @@ +use std::collections::BTreeMap; + +use axum::extract::State; +use conduit::{err, Err}; +use ruma::{api::client::message::send_message_event, events::MessageLikeEventType}; +use serde_json::from_str; + +use crate::{service::pdu::PduBuilder, utils, 
Result, Ruma}; + +/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` +/// +/// Send a message event into the room. +/// +/// - Is a NOOP if the txn id was already used before and returns the same event +/// id again +/// - The only requirement for the content is that it has to be valid json +/// - Tries to send the event into the room, auth rules will determine if it is +/// allowed +pub(crate) async fn send_message_event_route( + State(services): State<crate::State>, body: Ruma<send_message_event::v3::Request>, +) -> Result<send_message_event::v3::Response> { + let sender_user = body.sender_user(); + let sender_device = body.sender_device.as_deref(); + let appservice_info = body.appservice_info.as_ref(); + + // Forbid m.room.encrypted if encryption is disabled + if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() { + return Err!(Request(Forbidden("Encryption has been disabled"))); + } + + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + + if body.event_type == MessageLikeEventType::CallInvite + && services.rooms.directory.is_public_room(&body.room_id).await + { + return Err!(Request(Forbidden("Room call invites are not allowed in public rooms"))); + } + + // Check if this is a new transaction id + if let Ok(response) = services + .transaction_ids + .existing_txnid(sender_user, sender_device, &body.txn_id) + .await + { + // The client might have sent a txnid of the /sendToDevice endpoint + // This txnid has no response associated with it + if response.is_empty() { + return Err!(Request(InvalidParam( + "Tried to use txn id already used for an incompatible endpoint." + ))); + } + + return Ok(send_message_event::v3::Response { + event_id: utils::string_from_bytes(&response) + .map(TryInto::try_into) + .map_err(|e| err!(Database("Invalid event_id in txnid data: {e:?}")))??, + }); + } + + let mut unsigned = BTreeMap::new(); + unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); + + let content = + from_str(body.body.body.json().get()).map_err(|e| err!(Request(BadJson("Invalid JSON body: {e}"))))?; + + let event_id = services + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: body.event_type.clone().into(), + content, + unsigned: Some(unsigned), + timestamp: appservice_info.and(body.timestamp), + ..Default::default() + }, + sender_user, + &body.room_id, + &state_lock, + ) + .await?; + + services + .transaction_ids + .add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes()); + + drop(state_lock); + + Ok(send_message_event::v3::Response { + event_id: event_id.into(), + }) +} diff --git a/src/api/client/session.rs b/src/api/client/session.rs index 4702b0ec142e8242fab50a625f8654f39df036e6..6347a2c950679639b2bfcc1f988460f97243979a 100644 --- a/src/api/client/session.rs +++ b/src/api/client/session.rs @@ -1,5 +1,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; +use conduit::{debug, err, info, utils::ReadyExt, warn, Err}; +use futures::StreamExt; use ruma::{ api::client::{ error::ErrorKind, @@ -19,7 +21,6 @@ UserId, }; use serde::Deserialize; -use tracing::{debug, info, warn}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use crate::{utils, utils::hash, Error, Result, Ruma}; @@ -79,21 +80,22 @@ pub(crate) async fn login_route( UserId::parse(user) } else { warn!("Bad login type: {:?}", &body.login_info); - return Err(Error::BadRequest(ErrorKind::forbidden(), "Bad login type.")); + return Err!(Request(Forbidden("Bad login type."))); 
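+			// note: Err!(Request(Forbidden(..))) above yields the same M_FORBIDDEN
+			// client error as the Error::BadRequest form it replaced.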
} .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; let hash = services .users - .password_hash(&user_id)? - .ok_or(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password."))?; + .password_hash(&user_id) + .await + .map_err(|_| err!(Request(Forbidden("Wrong username or password."))))?; if hash.is_empty() { - return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated")); + return Err!(Request(UserDeactivated("The user has been deactivated"))); } if hash::verify_password(password, &hash).is_err() { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password.")); + return Err!(Request(Forbidden("Wrong username or password."))); } user_id @@ -112,15 +114,12 @@ pub(crate) async fn login_route( let username = token.claims.sub.to_lowercase(); - UserId::parse_with_server_name(username, services.globals.server_name()).map_err(|e| { - warn!("Failed to parse username from user logging in: {e}"); - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") - })? + UserId::parse_with_server_name(username, services.globals.server_name()) + .map_err(|e| err!(Request(InvalidUsername(debug_error!(?e, "Failed to parse login username")))))? } else { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Token login is not supported (server has no jwt decoding key).", - )); + return Err!(Request(Unknown( + "Token login is not supported (server has no jwt decoding key)." + ))); } }, #[allow(deprecated)] @@ -169,23 +168,32 @@ pub(crate) async fn login_route( let token = utils::random_string(TOKEN_LENGTH); // Determine if device_id was provided and exists in the db for this user - let device_exists = body.device_id.as_ref().map_or(false, |device_id| { + let device_exists = if body.device_id.is_some() { services .users .all_device_ids(&user_id) - .any(|x| x.as_ref().map_or(false, |v| v == device_id)) - }); + .ready_any(|v| v == device_id) + .await + } else { + false + }; if device_exists { - services.users.set_token(&user_id, &device_id, &token)?; + services + .users + .set_token(&user_id, &device_id, &token) + .await?; } else { - services.users.create_device( - &user_id, - &device_id, - &token, - body.initial_device_display_name.clone(), - Some(client.to_string()), - )?; + services + .users + .create_device( + &user_id, + &device_id, + &token, + body.initial_device_display_name.clone(), + Some(client.to_string()), + ) + .await?; } // send client well-known if specified so the client knows to reconfigure itself @@ -228,10 +236,13 @@ pub(crate) async fn logout_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - services.users.remove_device(sender_user, sender_device)?; + services + .users + .remove_device(sender_user, sender_device) + .await; // send device list update for user after logout - services.users.mark_device_key_update(sender_user)?; + services.users.mark_device_key_update(sender_user).await; Ok(logout::v3::Response::new()) } @@ -256,12 +267,14 @@ pub(crate) async fn logout_all_route( ) -> Result<logout_all::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for device_id in services.users.all_device_ids(sender_user).flatten() { - services.users.remove_device(sender_user, &device_id)?; - } + services + .users + .all_device_ids(sender_user) + .for_each(|device_id| services.users.remove_device(sender_user, device_id)) + .await; // send 
device list update for user after logout - services.users.mark_device_key_update(sender_user)?; + services.users.mark_device_key_update(sender_user).await; Ok(logout_all::v3::Response::new()) } diff --git a/src/api/client/state.rs b/src/api/client/state.rs index fd0496639fd97336268e5ac723d120992dc6aa56..5090d55755fae6a5f8db4a24c437b30c10af50ea 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::extract::State; -use conduit::{debug_info, error, pdu::PduBuilder, Error, Result}; +use conduit::{err, pdu::PduBuilder, utils::BoolExt, Err, Error, Result}; use ruma::{ api::client::{ error::ErrorKind, @@ -84,12 +84,10 @@ pub(crate) async fn get_state_events_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view the room state.", - )); + return Err!(Request(Forbidden("You don't have permission to view the room state."))); } Ok(get_state_events::v3::Response { @@ -120,43 +118,34 @@ pub(crate) async fn get_state_events_for_key_route( if !services .rooms .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? + .user_can_see_state_events(sender_user, &body.room_id) + .await { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view the room state.", - )); + return Err!(Request(Forbidden("You don't have permission to view the room state."))); } let event = services .rooms .state_accessor - .room_state_get(&body.room_id, &body.event_type, &body.state_key)? - .ok_or_else(|| { - debug_info!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); - Error::BadRequest(ErrorKind::NotFound, "State event not found.") + .room_state_get(&body.room_id, &body.event_type, &body.state_key) + .await + .map_err(|_| { + err!(Request(NotFound(debug_warn!( + room_id = ?body.room_id, + event_type = ?body.event_type, + "State event not found in room.", + )))) })?; - if body + + let event_format = body .format .as_ref() - .is_some_and(|f| f.to_lowercase().eq("event")) - { - Ok(get_state_events_for_key::v3::Response { - content: None, - event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { - error!("Invalid room state event in database: {}", e); - Error::bad_database("Invalid room state event in database") - })?, - }) - } else { - Ok(get_state_events_for_key::v3::Response { - content: Some(serde_json::from_str(event.content.get()).map_err(|e| { - error!("Invalid room state event content in database: {}", e); - Error::bad_database("Invalid room state event content in database") - })?), - event: None, - }) - } + .is_some_and(|f| f.to_lowercase().eq("event")); + + Ok(get_state_events_for_key::v3::Response { + content: event_format.or(|| event.get_content_as_value()), + event: event_format.then(|| event.to_state_event_value()), + }) } /// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}` @@ -187,11 +176,10 @@ async fn send_state_event_for_key_helper( .build_and_append_pdu( PduBuilder { event_type: event_type.to_string().into(), - content: serde_json::from_str(json.json().get()).expect("content is valid json"), - unsigned: None, + content: serde_json::from_str(json.json().get())?, state_key: Some(state_key), - redacts: None, timestamp, + ..Default::default() }, sender, room_id, @@ -204,7 +192,7 @@ async fn send_state_event_for_key_helper( async fn 
allowed_to_send_state_event( services: &Services, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, -) -> Result<()> { +) -> Result { match event_type { // Forbid m.room.encryption if encryption is disabled StateEventType::RoomEncryption => { @@ -214,7 +202,7 @@ async fn allowed_to_send_state_event( }, // admin room is a sensitive room, it should not ever be made public StateEventType::RoomJoinRules => { - if let Some(admin_room_id) = services.admin.get_admin_room()? { + if let Ok(admin_room_id) = services.admin.get_admin_room().await { if admin_room_id == room_id { if let Ok(join_rule) = serde_json::from_str::<RoomJoinRulesEventContent>(json.json().get()) { if join_rule.join_rule == JoinRule::Public { @@ -229,7 +217,7 @@ async fn allowed_to_send_state_event( }, // admin room is a sensitive room, it should not ever be made world readable StateEventType::RoomHistoryVisibility => { - if let Some(admin_room_id) = services.admin.get_admin_room()? { + if let Ok(admin_room_id) = services.admin.get_admin_room().await { if admin_room_id == room_id { if let Ok(visibility_content) = serde_json::from_str::<RoomHistoryVisibilityEventContent>(json.json().get()) @@ -254,23 +242,27 @@ async fn allowed_to_send_state_event( } for alias in aliases { - if !services.globals.server_is_ours(alias.server_name()) - || services - .rooms - .alias - .resolve_local_alias(&alias)? - .filter(|room| room == room_id) // Make sure it's the right room - .is_none() + if !services.globals.server_is_ours(alias.server_name()) { + return Err!(Request(Forbidden("canonical_alias must be for this server"))); + } + + if !services + .rooms + .alias + .resolve_local_alias(&alias) + .await + .is_ok_and(|room| room == room_id) + // Make sure it's the right room { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You are only allowed to send canonical_alias events when its aliases already exist", - )); + return Err!(Request(Forbidden( + "You are only allowed to send canonical_alias events when its aliases already exist" + ))); } } } }, _ => (), } + Ok(()) } diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs deleted file mode 100644 index eb534205e0199139727ea5d79f02e5f33cfba965..0000000000000000000000000000000000000000 --- a/src/api/client/sync.rs +++ /dev/null @@ -1,1783 +0,0 @@ -use std::{ - cmp::{self, Ordering}, - collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}, - time::Duration, -}; - -use axum::extract::State; -use conduit::{ - debug, error, - utils::math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, - warn, Err, PduCount, -}; -use ruma::{ - api::client::{ - error::ErrorKind, - filter::{FilterDefinition, LazyLoadOptions}, - sync::sync_events::{ - self, - v3::{ - Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, LeftRoom, Presence, - RoomAccountData, RoomSummary, Rooms, State as RoomState, Timeline, ToDevice, - }, - v4::{SlidingOp, SlidingSyncRoomHero}, - DeviceLists, UnreadNotificationsCount, - }, - uiaa::UiaaResponse, - }, - directory::RoomTypeFilter, - events::{ - presence::PresenceEvent, - room::member::{MembershipState, RoomMemberEventContent}, - AnyRawAccountDataEvent, StateEventType, TimelineEventType, - }, - serde::Raw, - state_res::Event, - uint, DeviceId, EventId, MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedUserId, RoomId, UInt, UserId, -}; -use service::rooms::read_receipt::pack_receipts; -use tracing::{Instrument as _, Span}; - -use crate::{ - service::{pdu::EventHash, Services}, - utils, 
Error, PduEvent, Result, Ruma, RumaResponse, -}; - -const SINGLE_CONNECTION_SYNC: &str = "single_connection_sync"; -const DEFAULT_BUMP_TYPES: &[TimelineEventType] = &[ - TimelineEventType::RoomMessage, - TimelineEventType::RoomEncrypted, - TimelineEventType::Sticker, - TimelineEventType::CallInvite, - TimelineEventType::PollStart, - TimelineEventType::Beacon, -]; - -macro_rules! extract_variant { - ($e:expr, $variant:path) => { - match $e { - $variant(value) => Some(value), - _ => None, - } - }; -} - -/// # `GET /_matrix/client/r0/sync` -/// -/// Synchronize the client's state with the latest state on the server. -/// -/// - This endpoint takes a `since` parameter which should be the `next_batch` -/// value from a previous request for incremental syncs. -/// -/// Calling this endpoint without a `since` parameter returns: -/// - Some of the most recent events of each timeline -/// - Notification counts for each room -/// - Joined and invited member counts, heroes -/// - All state events -/// -/// Calling this endpoint with a `since` parameter from a previous `next_batch` -/// returns: For joined rooms: -/// - Some of the most recent events of each timeline that happened after since -/// - If user joined the room after since: All state events (unless lazy loading -/// is activated) and all device list updates in that room -/// - If the user was already in the room: A list of all events that are in the -/// state now, but were not in the state at `since` -/// - If the state we send contains a member event: Joined and invited member -/// counts, heroes -/// - Device list updates that happened after `since` -/// - If there are events in the timeline we send or the user send updated his -/// read mark: Notification counts -/// - EDUs that are active now (read receipts, typing updates, presence) -/// - TODO: Allow multiple sync streams to support Pantalaimon -/// -/// For invited rooms: -/// - If the user was invited after `since`: A subset of the state of the room -/// at the point of the invite -/// -/// For left rooms: -/// - If the user left after `since`: `prev_batch` token, empty state (TODO: -/// subset of the state at the point of the leave) -pub(crate) async fn sync_events_route( - State(services): State<crate::State>, body: Ruma<sync_events::v3::Request>, -) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> { - let sender_user = body.sender_user.expect("user is authenticated"); - let sender_device = body.sender_device.expect("user is authenticated"); - let body = body.body; - - // Presence update - if services.globals.allow_local_presence() { - services - .presence - .ping_presence(&sender_user, &body.set_presence)?; - } - - // Setup watchers, so if there's no response, we can wait for them - let watcher = services.globals.watch(&sender_user, &sender_device); - - let next_batch = services.globals.current_count()?; - let next_batchcount = PduCount::Normal(next_batch); - let next_batch_string = next_batch.to_string(); - - // Load filter - let filter = match body.filter { - None => FilterDefinition::default(), - Some(Filter::FilterDefinition(filter)) => filter, - Some(Filter::FilterId(filter_id)) => services - .users - .get_filter(&sender_user, &filter_id)? 
- .unwrap_or_default(), - }; - - // some clients, at least element, seem to require knowledge of redundant - // members for "inline" profiles on the timeline to work properly - let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options { - LazyLoadOptions::Enabled { - include_redundant_members, - } => (true, include_redundant_members), - LazyLoadOptions::Disabled => (false, cfg!(feature = "element_hacks")), - }; - - let full_state = body.full_state; - - let mut joined_rooms = BTreeMap::new(); - let since = body - .since - .as_ref() - .and_then(|string| string.parse().ok()) - .unwrap_or(0); - let sincecount = PduCount::Normal(since); - - let mut presence_updates = HashMap::new(); - let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in - let mut device_list_updates = HashSet::new(); - let mut device_list_left = HashSet::new(); - - // Look for device list updates of this account - device_list_updates.extend( - services - .users - .keys_changed(sender_user.as_ref(), since, None) - .filter_map(Result::ok), - ); - - if services.globals.allow_local_presence() { - process_presence_updates(&services, &mut presence_updates, since, &sender_user).await?; - } - - let all_joined_rooms = services - .rooms - .state_cache - .rooms_joined(&sender_user) - .collect::<Vec<_>>(); - - // Coalesce database writes for the remainder of this scope. - let _cork = services.db.cork_and_flush(); - - for room_id in all_joined_rooms { - let room_id = room_id?; - if let Ok(joined_room) = load_joined_room( - &services, - &sender_user, - &sender_device, - &room_id, - since, - sincecount, - next_batch, - next_batchcount, - lazy_load_enabled, - lazy_load_send_redundant, - full_state, - &mut device_list_updates, - &mut left_encrypted_users, - ) - .await - { - if !joined_room.is_empty() { - joined_rooms.insert(room_id.clone(), joined_room); - } - } - } - - let mut left_rooms = BTreeMap::new(); - let all_left_rooms: Vec<_> = services - .rooms - .state_cache - .rooms_left(&sender_user) - .collect(); - for result in all_left_rooms { - handle_left_room( - &services, - since, - &result?.0, - &sender_user, - &mut left_rooms, - &next_batch_string, - full_state, - lazy_load_enabled, - ) - .instrument(Span::current()) - .await?; - } - - let mut invited_rooms = BTreeMap::new(); - let all_invited_rooms: Vec<_> = services - .rooms - .state_cache - .rooms_invited(&sender_user) - .collect(); - for result in all_invited_rooms { - let (room_id, invite_state_events) = result?; - - // Get and drop the lock to wait for remaining operations to finish - let insert_lock = services.rooms.timeline.mutex_insert.lock(&room_id).await; - drop(insert_lock); - - let invite_count = services - .rooms - .state_cache - .get_invite_count(&room_id, &sender_user)?; - - // Invited before last sync - if Some(since) >= invite_count { - continue; - } - - invited_rooms.insert( - room_id.clone(), - InvitedRoom { - invite_state: InviteState { - events: invite_state_events, - }, - }, - ); - } - - for user_id in left_encrypted_users { - let dont_share_encrypted_room = services - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? - .filter_map(Result::ok) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? 
- .is_some(), - ) - }) - .all(|encrypted| !encrypted); - // If the user doesn't share an encrypted room with the target anymore, we need - // to tell them - if dont_share_encrypted_room { - device_list_left.insert(user_id); - } - } - - // Remove all to-device events the device received *last time* - services - .users - .remove_to_device_events(&sender_user, &sender_device, since)?; - - let response = sync_events::v3::Response { - next_batch: next_batch_string, - rooms: Rooms { - leave: left_rooms, - join: joined_rooms, - invite: invited_rooms, - knock: BTreeMap::new(), // TODO - }, - presence: Presence { - events: presence_updates - .into_values() - .map(|v| Raw::new(&v).expect("PresenceEvent always serializes successfully")) - .collect(), - }, - account_data: GlobalAccountData { - events: services - .account_data - .changes_since(None, &sender_user, since)? - .into_iter() - .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) - .collect(), - }, - device_lists: DeviceLists { - changed: device_list_updates.into_iter().collect(), - left: device_list_left.into_iter().collect(), - }, - device_one_time_keys_count: services - .users - .count_one_time_keys(&sender_user, &sender_device)?, - to_device: ToDevice { - events: services - .users - .get_to_device_events(&sender_user, &sender_device)?, - }, - // Fallback keys are not yet supported - device_unused_fallback_key_types: None, - }; - - // TODO: Retry the endpoint instead of returning - if !full_state - && response.rooms.is_empty() - && response.presence.is_empty() - && response.account_data.is_empty() - && response.device_lists.is_empty() - && response.to_device.is_empty() - { - // Hang a few seconds so requests are not spammed - // Stop hanging if new info arrives - let default = Duration::from_secs(30); - let duration = cmp::min(body.timeout.unwrap_or(default), default); - _ = tokio::time::timeout(duration, watcher).await; - } - - Ok(response) -} - -#[allow(clippy::too_many_arguments)] -#[tracing::instrument(skip_all, fields(user_id = %sender_user, room_id = %room_id), name = "left_room")] -async fn handle_left_room( - services: &Services, since: u64, room_id: &RoomId, sender_user: &UserId, - left_rooms: &mut BTreeMap<OwnedRoomId, LeftRoom>, next_batch_string: &str, full_state: bool, - lazy_load_enabled: bool, -) -> Result<()> { - // Get and drop the lock to wait for remaining operations to finish - let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await; - drop(insert_lock); - - let left_count = services - .rooms - .state_cache - .get_left_count(room_id, sender_user)?; - - // Left before last sync - if Some(since) >= left_count { - return Ok(()); - } - - if !services.rooms.metadata.exists(room_id)? 
{ - // This is just a rejected invite, not a room we know - // Insert a leave event anyways - let event = PduEvent { - event_id: EventId::new(services.globals.server_name()).into(), - sender: sender_user.to_owned(), - origin: None, - origin_server_ts: utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - kind: TimelineEventType::RoomMember, - content: serde_json::from_str(r#"{"membership":"leave"}"#).expect("this is valid JSON"), - state_key: Some(sender_user.to_string()), - unsigned: None, - // The following keys are dropped on conversion - room_id: room_id.to_owned(), - prev_events: vec![], - depth: uint!(1), - auth_events: vec![], - redacts: None, - hashes: EventHash { - sha256: String::new(), - }, - signatures: None, - }; - - left_rooms.insert( - room_id.to_owned(), - LeftRoom { - account_data: RoomAccountData { - events: Vec::new(), - }, - timeline: Timeline { - limited: false, - prev_batch: Some(next_batch_string.to_owned()), - events: Vec::new(), - }, - state: RoomState { - events: vec![event.to_sync_state_event()], - }, - }, - ); - return Ok(()); - } - - let mut left_state_events = Vec::new(); - - let since_shortstatehash = services - .rooms - .user - .get_token_shortstatehash(room_id, since)?; - - let since_state_ids = match since_shortstatehash { - Some(s) => services.rooms.state_accessor.state_full_ids(s).await?, - None => HashMap::new(), - }; - - let Some(left_event_id) = - services - .rooms - .state_accessor - .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str())? - else { - error!("Left room but no left state event"); - return Ok(()); - }; - - let Some(left_shortstatehash) = services - .rooms - .state_accessor - .pdu_shortstatehash(&left_event_id)? - else { - error!(event_id = %left_event_id, "Leave event has no state"); - return Ok(()); - }; - - let mut left_state_ids = services - .rooms - .state_accessor - .state_full_ids(left_shortstatehash) - .await?; - - let leave_shortstatekey = services - .rooms - .short - .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; - - left_state_ids.insert(leave_shortstatekey, left_event_id); - - let mut i: u8 = 0; - for (key, id) in left_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { - let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key)?; - - if !lazy_load_enabled - || event_type != StateEventType::RoomMember - || full_state - // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 - || (cfg!(feature = "element_hacks") && *sender_user == state_key) - { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? 
else { - error!("Pdu in state not found: {}", id); - continue; - }; - - left_state_events.push(pdu.to_sync_state_event()); - - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - } - } - - left_rooms.insert( - room_id.to_owned(), - LeftRoom { - account_data: RoomAccountData { - events: Vec::new(), - }, - timeline: Timeline { - limited: false, - prev_batch: Some(next_batch_string.to_owned()), - events: Vec::new(), - }, - state: RoomState { - events: left_state_events, - }, - }, - ); - Ok(()) -} - -async fn process_presence_updates( - services: &Services, presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, since: u64, syncing_user: &UserId, -) -> Result<()> { - // Take presence updates - for (user_id, _, presence_bytes) in services.presence.presence_since(since) { - if !services - .rooms - .state_cache - .user_sees_user(syncing_user, &user_id)? - { - continue; - } - - let presence_event = services - .presence - .from_json_bytes_to_event(&presence_bytes, &user_id)?; - match presence_updates.entry(user_id) { - Entry::Vacant(slot) => { - slot.insert(presence_event); - }, - Entry::Occupied(mut slot) => { - let curr_event = slot.get_mut(); - let curr_content = &mut curr_event.content; - let new_content = presence_event.content; - - // Update existing presence event with more info - curr_content.presence = new_content.presence; - curr_content.status_msg = new_content - .status_msg - .or_else(|| curr_content.status_msg.take()); - curr_content.last_active_ago = new_content.last_active_ago.or(curr_content.last_active_ago); - curr_content.displayname = new_content - .displayname - .or_else(|| curr_content.displayname.take()); - curr_content.avatar_url = new_content - .avatar_url - .or_else(|| curr_content.avatar_url.take()); - curr_content.currently_active = new_content - .currently_active - .or(curr_content.currently_active); - }, - } - } - - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -async fn load_joined_room( - services: &Services, sender_user: &UserId, sender_device: &DeviceId, room_id: &RoomId, since: u64, - sincecount: PduCount, next_batch: u64, next_batchcount: PduCount, lazy_load_enabled: bool, - lazy_load_send_redundant: bool, full_state: bool, device_list_updates: &mut HashSet<OwnedUserId>, - left_encrypted_users: &mut HashSet<OwnedUserId>, -) -> Result<JoinedRoom> { - // Get and drop the lock to wait for remaining operations to finish - // This will make sure the we have all events until next_batch - let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await; - drop(insert_lock); - - let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10)?; - - let send_notification_counts = !timeline_pdus.is_empty() - || services - .rooms - .user - .last_notification_read(sender_user, room_id)? - > since; - - let mut timeline_users = HashSet::new(); - for (_, event) in &timeline_pdus { - timeline_users.insert(event.sender.as_str().to_owned()); - } - - services - .rooms - .lazy_loading - .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount) - .await?; - - // Database queries: - - let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? 
else { - return Err!(Database(error!("Room {room_id} has no state"))); - }; - - let since_shortstatehash = services - .rooms - .user - .get_token_shortstatehash(room_id, since)?; - - let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = - if timeline_pdus.is_empty() && since_shortstatehash == Some(current_shortstatehash) { - // No state changes - (Vec::new(), None, None, false, Vec::new()) - } else { - // Calculates joined_member_count, invited_member_count and heroes - let calculate_counts = || { - let joined_member_count = services - .rooms - .state_cache - .room_joined_count(room_id)? - .unwrap_or(0); - let invited_member_count = services - .rooms - .state_cache - .room_invited_count(room_id)? - .unwrap_or(0); - - // Recalculate heroes (first 5 members) - let mut heroes: Vec<OwnedUserId> = Vec::with_capacity(5); - - if joined_member_count.saturating_add(invited_member_count) <= 5 { - // Go through all PDUs and for each member event, check if the user is still - // joined or invited until we have 5 or we reach the end - - for hero in services - .rooms - .timeline - .all_pdus(sender_user, room_id)? - .filter_map(Result::ok) // Ignore all broken pdus - .filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) - .map(|(_, pdu)| { - let content: RoomMemberEventContent = serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; - - if let Some(state_key) = &pdu.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - - // The membership was and still is invite or join - if matches!(content.membership, MembershipState::Join | MembershipState::Invite) - && (services.rooms.state_cache.is_joined(&user_id, room_id)? - || services.rooms.state_cache.is_invited(&user_id, room_id)?) - { - Ok::<_, Error>(Some(user_id)) - } else { - Ok(None) - } - } else { - Ok(None) - } - }) - .filter_map(Result::ok) - // Filter for possible heroes - .flatten() - { - if heroes.contains(&hero) || hero == sender_user { - continue; - } - - heroes.push(hero); - } - } - - Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), heroes)) - }; - - let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash - .and_then(|shortstatehash| { - services - .rooms - .state_accessor - .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) - .transpose() - }) - .transpose()? - .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); - - let joined_since_last_sync = - since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); - - if since_shortstatehash.is_none() || joined_since_last_sync { - // Probably since = 0, we will do an initial sync - - let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; - - let current_state_ids = services - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - - let mut state_events = Vec::new(); - let mut lazy_loaded = HashSet::new(); - - let mut i: u8 = 0; - for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = services - .rooms - .short - .get_statekey_from_short(shortstatekey)?; - - if event_type != StateEventType::RoomMember { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? 
else { - error!("Pdu in state not found: {}", id); - continue; - }; - state_events.push(pdu); - - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } else if !lazy_load_enabled - || full_state - || timeline_users.contains(&state_key) - // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 - || (cfg!(feature = "element_hacks") && *sender_user == state_key) - { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; - - // This check is in case a bad user ID made it into the database - if let Ok(uid) = UserId::parse(&state_key) { - lazy_loaded.insert(uid); - } - state_events.push(pdu); - - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - } - - // Reset lazy loading because this is an initial sync - services - .rooms - .lazy_loading - .lazy_load_reset(sender_user, sender_device, room_id)?; - - // The state_events above should contain all timeline_users, let's mark them as - // lazy loaded. - services - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) - .await; - - (heroes, joined_member_count, invited_member_count, true, state_events) - } else { - // Incremental /sync - let since_shortstatehash = since_shortstatehash.unwrap(); - - let mut delta_state_events = Vec::new(); - - if since_shortstatehash != current_shortstatehash { - let current_state_ids = services - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let since_state_ids = services - .rooms - .state_accessor - .state_full_ids(since_shortstatehash) - .await?; - - for (key, id) in current_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; - - delta_state_events.push(pdu); - tokio::task::yield_now().await; - } - } - } - - let encrypted_room = services - .rooms - .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); - - let since_encryption = services.rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; - - // Calculations: - let new_encrypted_room = encrypted_room && since_encryption.is_none(); - - let send_member_count = delta_state_events - .iter() - .any(|event| event.kind == TimelineEventType::RoomMember); - - if encrypted_room { - for state_event in &delta_state_events { - if state_event.kind != TimelineEventType::RoomMember { - continue; - } - - if let Some(state_key) = &state_event.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - - if user_id == sender_user { - continue; - } - - let new_membership = - serde_json::from_str::<RoomMemberEventContent>(state_event.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; - - match new_membership { - MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room(services, sender_user, &user_id, room_id)? 
{ - device_list_updates.insert(user_id); - } - }, - MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); - }, - _ => {}, - } - } - } - } - - if joined_since_last_sync && encrypted_room || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_updates.extend( - services - .rooms - .state_cache - .room_members(room_id) - .flatten() - .filter(|user_id| { - // Don't send key updates from the sender to the sender - sender_user != user_id - }) - .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target - // already - !share_encrypted_room(services, sender_user, user_id, room_id).unwrap_or(false) - }), - ); - } - - let (joined_member_count, invited_member_count, heroes) = if send_member_count { - calculate_counts()? - } else { - (None, None, Vec::new()) - }; - - let mut state_events = delta_state_events; - let mut lazy_loaded = HashSet::new(); - - // Mark all member events we're returning as lazy-loaded - for pdu in &state_events { - if pdu.kind == TimelineEventType::RoomMember { - match UserId::parse( - pdu.state_key - .as_ref() - .expect("State event has state key") - .clone(), - ) { - Ok(state_key_userid) => { - lazy_loaded.insert(state_key_userid); - }, - Err(e) => error!("Invalid state key for member event: {}", e), - } - } - } - - // Fetch contextual member state events for events from the timeline, and - // mark them as lazy-loaded as well. - for (_, event) in &timeline_pdus { - if lazy_loaded.contains(&event.sender) { - continue; - } - - if !services.rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - room_id, - &event.sender, - )? || lazy_load_send_redundant - { - if let Some(member_event) = services.rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomMember, - event.sender.as_str(), - )? { - lazy_loaded.insert(event.sender.clone()); - state_events.push(member_event); - } - } - } - - services - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) - .await; - - ( - heroes, - joined_member_count, - invited_member_count, - joined_since_last_sync, - state_events, - ) - } - }; - - // Look for device list updates in this room - device_list_updates.extend( - services - .users - .keys_changed(room_id.as_ref(), since, None) - .filter_map(Result::ok), - ); - - let notification_count = if send_notification_counts { - Some( - services - .rooms - .user - .notification_count(sender_user, room_id)? - .try_into() - .expect("notification count can't go that high"), - ) - } else { - None - }; - - let highlight_count = if send_notification_counts { - Some( - services - .rooms - .user - .highlight_count(sender_user, room_id)? 
- .try_into() - .expect("highlight count can't go that high"), - ) - } else { - None - }; - - let prev_batch = timeline_pdus - .first() - .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { - Ok(Some(match pdu_count { - PduCount::Backfilled(_) => { - error!("timeline in backfill state?!"); - "0".to_owned() - }, - PduCount::Normal(c) => c.to_string(), - })) - })?; - - let room_events: Vec<_> = timeline_pdus - .iter() - .map(|(_, pdu)| pdu.to_sync_room_event()) - .collect(); - - let mut edus: Vec<_> = services - .rooms - .read_receipt - .readreceipts_since(room_id, since) - .filter_map(Result::ok) // Filter out buggy events - .map(|(_, _, v)| v) - .collect(); - - if services.rooms.typing.last_typing_update(room_id).await? > since { - edus.push( - serde_json::from_str( - &serde_json::to_string(&services.rooms.typing.typings_all(room_id).await?) - .expect("event is valid, we just created it"), - ) - .expect("event is valid, we just created it"), - ); - } - - // Save the state after this sync so we can send the correct state diff next - // sync - services - .rooms - .user - .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?; - - Ok(JoinedRoom { - account_data: RoomAccountData { - events: services - .account_data - .changes_since(Some(room_id), sender_user, since)? - .into_iter() - .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) - .collect(), - }, - summary: RoomSummary { - heroes, - joined_member_count: joined_member_count.map(ruma_from_u64), - invited_member_count: invited_member_count.map(ruma_from_u64), - }, - unread_notifications: UnreadNotificationsCount { - highlight_count, - notification_count, - }, - timeline: Timeline { - limited: limited || joined_since_last_sync, - prev_batch, - events: room_events, - }, - state: RoomState { - events: state_events - .iter() - .map(|pdu| pdu.to_sync_state_event()) - .collect(), - }, - ephemeral: Ephemeral { - events: edus, - }, - unread_thread_notifications: BTreeMap::new(), - }) -} - -fn load_timeline( - services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, -) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { - let timeline_pdus; - let limited = if services - .rooms - .timeline - .last_timeline_count(sender_user, room_id)? - > roomsincecount - { - let mut non_timeline_pdus = services - .rooms - .timeline - .pdus_until(sender_user, room_id, PduCount::max())? - .filter_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) - .take_while(|(pducount, _)| pducount > &roomsincecount); - - // Take the last events for the timeline - timeline_pdus = non_timeline_pdus - .by_ref() - .take(usize_from_u64_truncated(limit)) - .collect::<Vec<_>>() - .into_iter() - .rev() - .collect::<Vec<_>>(); - - // They /sync response doesn't always return all messages, so we say the output - // is limited unless there are events in non_timeline_pdus - non_timeline_pdus.next().is_some() - } else { - timeline_pdus = Vec::new(); - false - }; - Ok((timeline_pdus, limited)) -} - -fn share_encrypted_room( - services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId, -) -> Result<bool> { - Ok(services - .rooms - .user - .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? 
- .filter_map(Result::ok) - .filter(|room_id| room_id != ignore_room) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .any(|encrypted| encrypted)) -} - -/// POST `/_matrix/client/unstable/org.matrix.msc3575/sync` -/// -/// Sliding Sync endpoint (future endpoint: `/_matrix/client/v4/sync`) -pub(crate) async fn sync_events_v4_route( - State(services): State<crate::State>, body: Ruma<sync_events::v4::Request>, -) -> Result<sync_events::v4::Response> { - let sender_user = body.sender_user.expect("user is authenticated"); - let sender_device = body.sender_device.expect("user is authenticated"); - let mut body = body.body; - // Setup watchers, so if there's no response, we can wait for them - let watcher = services.globals.watch(&sender_user, &sender_device); - - let next_batch = services.globals.next_count()?; - - let conn_id = body - .conn_id - .clone() - .unwrap_or_else(|| SINGLE_CONNECTION_SYNC.to_owned()); - - let globalsince = body - .pos - .as_ref() - .and_then(|string| string.parse().ok()) - .unwrap_or(0); - - if globalsince != 0 - && !services - .users - .remembered(sender_user.clone(), sender_device.clone(), conn_id.clone()) - { - debug!("Restarting sync stream because it was gone from the database"); - return Err(Error::Request( - ErrorKind::UnknownPos, - "Connection data lost since last time".into(), - http::StatusCode::BAD_REQUEST, - )); - } - - if globalsince == 0 { - services - .users - .forget_sync_request_connection(sender_user.clone(), sender_device.clone(), conn_id.clone()); - } - - // Get sticky parameters from cache - let known_rooms = - services - .users - .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); - - let all_joined_rooms = services - .rooms - .state_cache - .rooms_joined(&sender_user) - .filter_map(Result::ok) - .collect::<Vec<_>>(); - - let all_invited_rooms = services - .rooms - .state_cache - .rooms_invited(&sender_user) - .filter_map(Result::ok) - .map(|r| r.0) - .collect::<Vec<_>>(); - - let all_rooms = all_joined_rooms - .iter() - .cloned() - .chain(all_invited_rooms.clone()) - .collect(); - - if body.extensions.to_device.enabled.unwrap_or(false) { - services - .users - .remove_to_device_events(&sender_user, &sender_device, globalsince)?; - } - - let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in - let mut device_list_changes = HashSet::new(); - let mut device_list_left = HashSet::new(); - - let mut receipts = sync_events::v4::Receipts { - rooms: BTreeMap::new(), - }; - - let mut account_data = sync_events::v4::AccountData { - global: Vec::new(), - rooms: BTreeMap::new(), - }; - if body.extensions.account_data.enabled.unwrap_or(false) { - account_data.global = services - .account_data - .changes_since(None, &sender_user, globalsince)? - .into_iter() - .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) - .collect(); - - if let Some(rooms) = body.extensions.account_data.rooms { - for room in rooms { - account_data.rooms.insert( - room.clone(), - services - .account_data - .changes_since(Some(&room), &sender_user, globalsince)? 
- .into_iter() - .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) - .collect(), - ); - } - } - } - - if body.extensions.e2ee.enabled.unwrap_or(false) { - // Look for device list updates of this account - device_list_changes.extend( - services - .users - .keys_changed(sender_user.as_ref(), globalsince, None) - .filter_map(Result::ok), - ); - - for room_id in &all_joined_rooms { - let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? else { - error!("Room {} has no state", room_id); - continue; - }; - - let since_shortstatehash = services - .rooms - .user - .get_token_shortstatehash(room_id, globalsince)?; - - let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash - .and_then(|shortstatehash| { - services - .rooms - .state_accessor - .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) - .transpose() - }) - .transpose()? - .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); - - let encrypted_room = services - .rooms - .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); - - if let Some(since_shortstatehash) = since_shortstatehash { - // Skip if there are only timeline changes - if since_shortstatehash == current_shortstatehash { - continue; - } - - let since_encryption = services.rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; - - let joined_since_last_sync = - since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); - - let new_encrypted_room = encrypted_room && since_encryption.is_none(); - if encrypted_room { - let current_state_ids = services - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let since_state_ids = services - .rooms - .state_accessor - .state_full_ids(since_shortstatehash) - .await?; - - for (key, id) in current_state_ids { - if since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { - error!("Pdu in state not found: {}", id); - continue; - }; - if pdu.kind == TimelineEventType::RoomMember { - if let Some(state_key) = &pdu.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - - if user_id == sender_user { - continue; - } - - let new_membership = - serde_json::from_str::<RoomMemberEventContent>(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; - - match new_membership { - MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room(&services, &sender_user, &user_id, room_id)? 
{ - device_list_changes.insert(user_id); - } - }, - MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); - }, - _ => {}, - } - } - } - } - } - if joined_since_last_sync || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_changes.extend( - services - .rooms - .state_cache - .room_members(room_id) - .flatten() - .filter(|user_id| { - // Don't send key updates from the sender to the sender - &sender_user != user_id - }) - .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target - // already - !share_encrypted_room(&services, &sender_user, user_id, room_id).unwrap_or(false) - }), - ); - } - } - } - // Look for device list updates in this room - device_list_changes.extend( - services - .users - .keys_changed(room_id.as_ref(), globalsince, None) - .filter_map(Result::ok), - ); - } - for user_id in left_encrypted_users { - let dont_share_encrypted_room = services - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? - .filter_map(Result::ok) - .filter_map(|other_room_id| { - Some( - services - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .all(|encrypted| !encrypted); - // If the user doesn't share an encrypted room with the target anymore, we need - // to tell them - if dont_share_encrypted_room { - device_list_left.insert(user_id); - } - } - } - - let mut lists = BTreeMap::new(); - let mut todo_rooms = BTreeMap::new(); // and required state - - for (list_id, list) in body.lists { - let active_rooms = match list.filters.clone().and_then(|f| f.is_invite) { - Some(true) => &all_invited_rooms, - Some(false) => &all_joined_rooms, - None => &all_rooms, - }; - - let active_rooms = match list.filters.clone().map(|f| f.not_room_types) { - Some(filter) if filter.is_empty() => active_rooms.clone(), - Some(value) => filter_rooms(active_rooms, State(services), &value, true), - None => active_rooms.clone(), - }; - - let active_rooms = match list.filters.clone().map(|f| f.room_types) { - Some(filter) if filter.is_empty() => active_rooms.clone(), - Some(value) => filter_rooms(&active_rooms, State(services), &value, false), - None => active_rooms, - }; - - let mut new_known_rooms = BTreeSet::new(); - - lists.insert( - list_id.clone(), - sync_events::v4::SyncList { - ops: list - .ranges - .into_iter() - .map(|mut r| { - r.0 = r.0.clamp( - uint!(0), - UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX), - ); - r.1 = - r.1.clamp(r.0, UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX)); - let room_ids = if !active_rooms.is_empty() { - active_rooms[usize_from_ruma(r.0)..=usize_from_ruma(r.1)].to_vec() - } else { - Vec::new() - }; - new_known_rooms.extend(room_ids.iter().cloned()); - for room_id in &room_ids { - let todo_room = todo_rooms - .entry(room_id.clone()) - .or_insert((BTreeSet::new(), 0, u64::MAX)); - let limit = list - .room_details - .timeline_limit - .map_or(10, u64::from) - .min(100); - todo_room - .0 - .extend(list.room_details.required_state.iter().cloned()); - todo_room.1 = todo_room.1.max(limit); - // 0 means unknown because it got out of date - todo_room.2 = todo_room.2.min( - known_rooms - .get(&list_id) - .and_then(|k| k.get(room_id)) - .copied() - .unwrap_or(0), - ); - } - sync_events::v4::SyncOp { - op: SlidingOp::Sync, - range: Some(r), - index: None, - 
room_ids, - room_id: None, - } - }) - .collect(), - count: ruma_from_usize(active_rooms.len()), - }, - ); - - if let Some(conn_id) = &body.conn_id { - services.users.update_sync_known_rooms( - sender_user.clone(), - sender_device.clone(), - conn_id.clone(), - list_id, - new_known_rooms, - globalsince, - ); - } - } - - let mut known_subscription_rooms = BTreeSet::new(); - for (room_id, room) in &body.room_subscriptions { - if !services.rooms.metadata.exists(room_id)? { - continue; - } - let todo_room = todo_rooms - .entry(room_id.clone()) - .or_insert((BTreeSet::new(), 0, u64::MAX)); - let limit = room.timeline_limit.map_or(10, u64::from).min(100); - todo_room.0.extend(room.required_state.iter().cloned()); - todo_room.1 = todo_room.1.max(limit); - // 0 means unknown because it got out of date - todo_room.2 = todo_room.2.min( - known_rooms - .get("subscriptions") - .and_then(|k| k.get(room_id)) - .copied() - .unwrap_or(0), - ); - known_subscription_rooms.insert(room_id.clone()); - } - - for r in body.unsubscribe_rooms { - known_subscription_rooms.remove(&r); - body.room_subscriptions.remove(&r); - } - - if let Some(conn_id) = &body.conn_id { - services.users.update_sync_known_rooms( - sender_user.clone(), - sender_device.clone(), - conn_id.clone(), - "subscriptions".to_owned(), - known_subscription_rooms, - globalsince, - ); - } - - if let Some(conn_id) = &body.conn_id { - services.users.update_sync_subscriptions( - sender_user.clone(), - sender_device.clone(), - conn_id.clone(), - body.room_subscriptions, - ); - } - - let mut rooms = BTreeMap::new(); - for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms { - let roomsincecount = PduCount::Normal(*roomsince); - - let mut timestamp: Option<_> = None; - let mut invite_state = None; - let (timeline_pdus, limited); - if all_invited_rooms.contains(room_id) { - // TODO: figure out a timestamp we can use for remote invites - invite_state = services - .rooms - .state_cache - .invite_state(&sender_user, room_id) - .unwrap_or(None); - - (timeline_pdus, limited) = (Vec::new(), true); - } else { - (timeline_pdus, limited) = - match load_timeline(&services, &sender_user, room_id, roomsincecount, *timeline_limit) { - Ok(value) => value, - Err(err) => { - warn!("Encountered missing timeline in {}, error {}", room_id, err); - continue; - }, - }; - } - - account_data.rooms.insert( - room_id.clone(), - services - .account_data - .changes_since(Some(room_id), &sender_user, *roomsince)? - .into_iter() - .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) - .collect(), - ); - - let room_receipts = services - .rooms - .read_receipt - .readreceipts_since(room_id, *roomsince); - let vector: Vec<_> = room_receipts.into_iter().collect(); - let receipt_size = vector.len(); - receipts - .rooms - .insert(room_id.clone(), pack_receipts(Box::new(vector.into_iter()))); - - if roomsince != &0 - && timeline_pdus.is_empty() - && account_data.rooms.get(room_id).is_some_and(Vec::is_empty) - && receipt_size == 0 - { - continue; - } - - let prev_batch = timeline_pdus - .first() - .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { - Ok(Some(match pdu_count { - PduCount::Backfilled(_) => { - error!("timeline in backfill state?!"); - "0".to_owned() - }, - PduCount::Normal(c) => c.to_string(), - })) - })? 
- .or_else(|| { - if roomsince != &0 { - Some(roomsince.to_string()) - } else { - None - } - }); - - let room_events: Vec<_> = timeline_pdus - .iter() - .map(|(_, pdu)| pdu.to_sync_room_event()) - .collect(); - - for (_, pdu) in timeline_pdus { - let ts = MilliSecondsSinceUnixEpoch(pdu.origin_server_ts); - if DEFAULT_BUMP_TYPES.contains(pdu.event_type()) && !timestamp.is_some_and(|time| time > ts) { - timestamp = Some(ts); - } - } - - let required_state = required_state_request - .iter() - .map(|state| { - services - .rooms - .state_accessor - .room_state_get(room_id, &state.0, &state.1) - }) - .filter_map(Result::ok) - .flatten() - .map(|state| state.to_sync_state_event()) - .collect(); - - // Heroes - let heroes = services - .rooms - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .filter(|member| member != &sender_user) - .map(|member| { - Ok::<_, Error>( - services - .rooms - .state_accessor - .get_member(room_id, &member)? - .map(|memberevent| SlidingSyncRoomHero { - user_id: member, - name: memberevent.displayname, - avatar: memberevent.avatar_url, - }), - ) - }) - .filter_map(Result::ok) - .flatten() - .take(5) - .collect::<Vec<_>>(); - let name = match heroes.len().cmp(&(1_usize)) { - Ordering::Greater => { - let firsts = heroes[1..] - .iter() - .map(|h| h.name.clone().unwrap_or_else(|| h.user_id.to_string())) - .collect::<Vec<_>>() - .join(", "); - let last = heroes[0] - .name - .clone() - .unwrap_or_else(|| heroes[0].user_id.to_string()); - Some(format!("{firsts} and {last}")) - }, - Ordering::Equal => Some( - heroes[0] - .name - .clone() - .unwrap_or_else(|| heroes[0].user_id.to_string()), - ), - Ordering::Less => None, - }; - - let heroes_avatar = if heroes.len() == 1 { - heroes[0].avatar.clone() - } else { - None - }; - - rooms.insert( - room_id.clone(), - sync_events::v4::SlidingSyncRoom { - name: services.rooms.state_accessor.get_name(room_id)?.or(name), - avatar: if let Some(heroes_avatar) = heroes_avatar { - ruma::JsOption::Some(heroes_avatar) - } else { - match services.rooms.state_accessor.get_avatar(room_id)? { - ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url), - ruma::JsOption::Null => ruma::JsOption::Null, - ruma::JsOption::Undefined => ruma::JsOption::Undefined, - } - }, - initial: Some(roomsince == &0), - is_dm: None, - invite_state, - unread_notifications: UnreadNotificationsCount { - highlight_count: Some( - services - .rooms - .user - .highlight_count(&sender_user, room_id)? - .try_into() - .expect("notification count can't go that high"), - ), - notification_count: Some( - services - .rooms - .user - .notification_count(&sender_user, room_id)? - .try_into() - .expect("notification count can't go that high"), - ), - }, - timeline: room_events, - required_state, - prev_batch, - limited, - joined_count: Some( - services - .rooms - .state_cache - .room_joined_count(room_id)? - .unwrap_or(0) - .try_into() - .unwrap_or_else(|_| uint!(0)), - ), - invited_count: Some( - services - .rooms - .state_cache - .room_invited_count(room_id)? 
- .unwrap_or(0) - .try_into() - .unwrap_or_else(|_| uint!(0)), - ), - num_live: None, // Count events in timeline greater than global sync counter - timestamp, - heroes: Some(heroes), - }, - ); - } - - if rooms - .iter() - .all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty()) - { - // Hang a few seconds so requests are not spammed - // Stop hanging if new info arrives - let default = Duration::from_secs(30); - let duration = cmp::min(body.timeout.unwrap_or(default), default); - _ = tokio::time::timeout(duration, watcher).await; - } - - Ok(sync_events::v4::Response { - initial: globalsince == 0, - txn_id: body.txn_id.clone(), - pos: next_batch.to_string(), - lists, - rooms, - extensions: sync_events::v4::Extensions { - to_device: if body.extensions.to_device.enabled.unwrap_or(false) { - Some(sync_events::v4::ToDevice { - events: services - .users - .get_to_device_events(&sender_user, &sender_device)?, - next_batch: next_batch.to_string(), - }) - } else { - None - }, - e2ee: sync_events::v4::E2EE { - device_lists: DeviceLists { - changed: device_list_changes.into_iter().collect(), - left: device_list_left.into_iter().collect(), - }, - device_one_time_keys_count: services - .users - .count_one_time_keys(&sender_user, &sender_device)?, - // Fallback keys are not yet supported - device_unused_fallback_key_types: None, - }, - account_data, - receipts, - typing: sync_events::v4::Typing { - rooms: BTreeMap::new(), - }, - }, - delta_token: None, - }) -} - -fn filter_rooms( - rooms: &[OwnedRoomId], State(services): State<crate::State>, filter: &[RoomTypeFilter], negate: bool, -) -> Vec<OwnedRoomId> { - return rooms - .iter() - .filter(|r| match services.rooms.state_accessor.get_room_type(r) { - Err(e) => { - warn!("Requested room type for {}, but could not retrieve with error {}", r, e); - false - }, - Ok(result) => { - let result = RoomTypeFilter::from(result); - if negate { - !filter.contains(&result) - } else { - filter.is_empty() || filter.contains(&result) - } - }, - }) - .cloned() - .collect(); -} diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..3201b8276a22a4015f17f7b8b6d94baf2c7a777c --- /dev/null +++ b/src/api/client/sync/mod.rs @@ -0,0 +1,63 @@ +mod v3; +mod v4; + +use conduit::{utils::ReadyExt, PduCount}; +use futures::StreamExt; +use ruma::{RoomId, UserId}; + +pub(crate) use self::{v3::sync_events_route, v4::sync_events_v4_route}; +use crate::{service::Services, Error, PduEvent, Result}; + +async fn load_timeline( + services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: usize, +) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { + let last_timeline_count = services + .rooms + .timeline + .last_timeline_count(Some(sender_user), room_id) + .await?; + + if last_timeline_count <= roomsincecount { + return Ok((Vec::new(), false)); + } + + let mut non_timeline_pdus = services + .rooms + .timeline + .pdus_rev(Some(sender_user), room_id, None) + .await? 
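+		// Editor's note: pdus_rev pages the timeline newest-first; the take_while
+		// below keeps only events strictly newer than the client's `since`
+		// position, stopping at the first event at or before it.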
+		.ready_take_while(|(pducount, _)| *pducount > roomsincecount);
+
+	// Take the last events for the timeline
+	let timeline_pdus: Vec<_> = non_timeline_pdus
+		.by_ref()
+		.take(limit)
+		.collect::<Vec<_>>()
+		.await
+		.into_iter()
+		.rev()
+		.collect();
+
+	// The /sync response doesn't always return all messages, so we say the output
+	// is limited if further events remain in non_timeline_pdus
+	let limited = non_timeline_pdus.next().await.is_some();
+
+	Ok((timeline_pdus, limited))
+}
+
+async fn share_encrypted_room(
+	services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: Option<&RoomId>,
+) -> bool {
+	services
+		.rooms
+		.user
+		.get_shared_rooms(sender_user, user_id)
+		.ready_filter(|&room_id| Some(room_id) != ignore_room)
+		.any(|other_room_id| {
+			services
+				.rooms
+				.state_accessor
+				.is_encrypted_room(other_room_id)
+		})
+		.await
+}
diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs
new file mode 100644
index 0000000000000000000000000000000000000000..ea487d8e296ade5cade896637ad60e1817d3b560
--- /dev/null
+++ b/src/api/client/sync/v3.rs
@@ -0,0 +1,1056 @@
+use std::{
+	cmp::{self},
+	collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
+	time::Duration,
+};
+
+use axum::extract::State;
+use conduit::{
+	at, err, error, extract_variant, is_equal_to,
+	result::FlatOk,
+	utils::{math::ruma_from_u64, BoolExt, IterStream, ReadyExt, TryFutureExtExt},
+	PduCount,
+};
+use futures::{future::OptionFuture, pin_mut, FutureExt, StreamExt};
+use ruma::{
+	api::client::{
+		filter::{FilterDefinition, LazyLoadOptions},
+		sync::sync_events::{
+			self,
+			v3::{
+				Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, LeftRoom, Presence,
+				RoomAccountData, RoomSummary, Rooms, State as RoomState, Timeline, ToDevice,
+			},
+			DeviceLists, UnreadNotificationsCount,
+		},
+		uiaa::UiaaResponse,
+	},
+	events::{
+		presence::PresenceEvent,
+		room::member::{MembershipState, RoomMemberEventContent},
+		AnyRawAccountDataEvent, AnySyncEphemeralRoomEvent, StateEventType,
+		TimelineEventType::*,
+	},
+	serde::Raw,
+	uint, DeviceId, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
+};
+use tracing::{Instrument as _, Span};
+
+use super::{load_timeline, share_encrypted_room};
+use crate::{
+	service::{pdu::EventHash, Services},
+	utils, Error, PduEvent, Result, Ruma, RumaResponse,
+};
+
+/// # `GET /_matrix/client/r0/sync`
+///
+/// Synchronize the client's state with the latest state on the server.
+///
+/// - This endpoint takes a `since` parameter which should be the `next_batch`
+///   value from a previous request for incremental syncs.
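+///   (Clients should treat the token as opaque.)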
+///
+/// Calling this endpoint without a `since` parameter returns:
+/// - Some of the most recent events of each timeline
+/// - Notification counts for each room
+/// - Joined and invited member counts, heroes
+/// - All state events
+///
+/// Calling this endpoint with a `since` parameter from a previous `next_batch`
+/// returns, for joined rooms:
+/// - Some of the most recent events of each timeline that happened after `since`
+/// - If the user joined the room after `since`: All state events (unless lazy
+///   loading is activated) and all device list updates in that room
+/// - If the user was already in the room: A list of all events that are in the
+///   state now, but were not in the state at `since`
+/// - If the state we send contains a member event: Joined and invited member
+///   counts, heroes
+/// - Device list updates that happened after `since`
+/// - If there are events in the timeline we send, or the user updated their read
+///   marker: Notification counts
+/// - EDUs that are active now (read receipts, typing updates, presence)
+/// - TODO: Allow multiple sync streams to support Pantalaimon
+///
+/// For invited rooms:
+/// - If the user was invited after `since`: A subset of the state of the room
+///   at the point of the invite
+///
+/// For left rooms:
+/// - If the user left after `since`: `prev_batch` token, empty state (TODO:
+///   subset of the state at the point of the leave)
+pub(crate) async fn sync_events_route(
+    State(services): State<crate::State>, body: Ruma<sync_events::v3::Request>,
+) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> {
+    let sender_user = body.sender_user.expect("user is authenticated");
+    let sender_device = body.sender_device.expect("user is authenticated");
+    let body = body.body;
+
+    // Presence update
+    if services.globals.allow_local_presence() {
+        services
+            .presence
+            .ping_presence(&sender_user, &body.set_presence)
+            .await?;
+    }
+
+    // Setup watchers, so if there's no response, we can wait for them
+    let watcher = services.sync.watch(&sender_user, &sender_device);
+
+    let next_batch = services.globals.current_count()?;
+    let next_batchcount = PduCount::Normal(next_batch);
+    let next_batch_string = next_batch.to_string();
+
+    // Load filter
+    let filter = match body.filter {
+        None => FilterDefinition::default(),
+        Some(Filter::FilterDefinition(filter)) => filter,
+        Some(Filter::FilterId(filter_id)) => services
+            .users
+            .get_filter(&sender_user, &filter_id)
+            .await
+            .unwrap_or_default(),
+    };
+
+    // Some clients, at least Element, seem to require knowledge of redundant
+    // members for "inline" profiles on the timeline to work properly
+    let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options {
+        LazyLoadOptions::Enabled {
+            include_redundant_members,
+        } => (true, include_redundant_members),
+        LazyLoadOptions::Disabled => (false, cfg!(feature = "element_hacks")),
+    };
+
+    let full_state = body.full_state;
+
+    let mut joined_rooms = BTreeMap::new();
+    let since = body
+        .since
+        .as_ref()
+        .and_then(|string| string.parse().ok())
+        .unwrap_or(0);
+    let sincecount = PduCount::Normal(since);
+
+    let mut presence_updates = HashMap::new();
+    let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in
+    let mut device_list_updates = HashSet::new();
+    let mut device_list_left = HashSet::new();
+
+    // Look for device list updates of this account
+    device_list_updates.extend(
+        services
+            .users
+            .keys_changed(&sender_user, since, None)
+
.map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await, + ); + + if services.globals.allow_local_presence() { + process_presence_updates(&services, &mut presence_updates, since, &sender_user).await?; + } + + let all_joined_rooms: Vec<_> = services + .rooms + .state_cache + .rooms_joined(&sender_user) + .map(ToOwned::to_owned) + .collect() + .await; + + // Coalesce database writes for the remainder of this scope. + let _cork = services.db.cork_and_flush(); + + for room_id in all_joined_rooms { + if let Ok(joined_room) = load_joined_room( + &services, + &sender_user, + &sender_device, + &room_id, + since, + sincecount, + next_batch, + next_batchcount, + lazy_load_enabled, + lazy_load_send_redundant, + full_state, + &mut device_list_updates, + &mut left_encrypted_users, + ) + .await + { + if !joined_room.is_empty() { + joined_rooms.insert(room_id.clone(), joined_room); + } + } + } + + let mut left_rooms = BTreeMap::new(); + let all_left_rooms: Vec<_> = services + .rooms + .state_cache + .rooms_left(&sender_user) + .collect() + .await; + + for result in all_left_rooms { + handle_left_room( + &services, + since, + &result.0, + &sender_user, + &mut left_rooms, + &next_batch_string, + full_state, + lazy_load_enabled, + ) + .instrument(Span::current()) + .await?; + } + + let mut invited_rooms = BTreeMap::new(); + let all_invited_rooms: Vec<_> = services + .rooms + .state_cache + .rooms_invited(&sender_user) + .collect() + .await; + + for (room_id, invite_state_events) in all_invited_rooms { + // Get and drop the lock to wait for remaining operations to finish + let insert_lock = services.rooms.timeline.mutex_insert.lock(&room_id).await; + drop(insert_lock); + + let invite_count = services + .rooms + .state_cache + .get_invite_count(&room_id, &sender_user) + .await + .ok(); + + // Invited before last sync + if Some(since) >= invite_count { + continue; + } + + invited_rooms.insert( + room_id.clone(), + InvitedRoom { + invite_state: InviteState { + events: invite_state_events, + }, + }, + ); + } + + for user_id in left_encrypted_users { + let dont_share_encrypted_room = !share_encrypted_room(&services, &sender_user, &user_id, None).await; + + // If the user doesn't share an encrypted room with the target anymore, we need + // to tell them + if dont_share_encrypted_room { + device_list_left.insert(user_id); + } + } + + // Remove all to-device events the device received *last time* + services + .users + .remove_to_device_events(&sender_user, &sender_device, since) + .await; + + let response = sync_events::v3::Response { + next_batch: next_batch_string, + rooms: Rooms { + leave: left_rooms, + join: joined_rooms, + invite: invited_rooms, + knock: BTreeMap::new(), // TODO + }, + presence: Presence { + events: presence_updates + .into_values() + .map(|v| Raw::new(&v).expect("PresenceEvent always serializes successfully")) + .collect(), + }, + account_data: GlobalAccountData { + events: services + .account_data + .changes_since(None, &sender_user, since) + .await? 
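+                // keep only the global variants here; per-room account data is
+                // attached to each room's response separately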
+ .into_iter() + .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) + .collect(), + }, + device_lists: DeviceLists { + changed: device_list_updates.into_iter().collect(), + left: device_list_left.into_iter().collect(), + }, + device_one_time_keys_count: services + .users + .count_one_time_keys(&sender_user, &sender_device) + .await, + to_device: ToDevice { + events: services + .users + .get_to_device_events(&sender_user, &sender_device) + .collect() + .await, + }, + // Fallback keys are not yet supported + device_unused_fallback_key_types: None, + }; + + // TODO: Retry the endpoint instead of returning + if !full_state + && response.rooms.is_empty() + && response.presence.is_empty() + && response.account_data.is_empty() + && response.device_lists.is_empty() + && response.to_device.is_empty() + { + // Hang a few seconds so requests are not spammed + // Stop hanging if new info arrives + let default = Duration::from_secs(30); + let duration = cmp::min(body.timeout.unwrap_or(default), default); + _ = tokio::time::timeout(duration, watcher).await; + } + + Ok(response) +} + +#[allow(clippy::too_many_arguments)] +#[tracing::instrument(skip_all, fields(user_id = %sender_user, room_id = %room_id), name = "left_room")] +async fn handle_left_room( + services: &Services, since: u64, room_id: &RoomId, sender_user: &UserId, + left_rooms: &mut BTreeMap<OwnedRoomId, LeftRoom>, next_batch_string: &str, full_state: bool, + lazy_load_enabled: bool, +) -> Result<()> { + // Get and drop the lock to wait for remaining operations to finish + let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await; + drop(insert_lock); + + let left_count = services + .rooms + .state_cache + .get_left_count(room_id, sender_user) + .await + .ok(); + + // Left before last sync + if Some(since) >= left_count { + return Ok(()); + } + + if !services.rooms.metadata.exists(room_id).await { + // This is just a rejected invite, not a room we know + // Insert a leave event anyways + let event = PduEvent { + event_id: EventId::new(services.globals.server_name()).into(), + sender: sender_user.to_owned(), + origin: None, + origin_server_ts: utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + kind: RoomMember, + content: serde_json::from_str(r#"{"membership":"leave"}"#).expect("this is valid JSON"), + state_key: Some(sender_user.to_string()), + unsigned: None, + // The following keys are dropped on conversion + room_id: room_id.to_owned(), + prev_events: vec![], + depth: uint!(1), + auth_events: vec![], + redacts: None, + hashes: EventHash { + sha256: String::new(), + }, + signatures: None, + }; + + left_rooms.insert( + room_id.to_owned(), + LeftRoom { + account_data: RoomAccountData { + events: Vec::new(), + }, + timeline: Timeline { + limited: false, + prev_batch: Some(next_batch_string.to_owned()), + events: Vec::new(), + }, + state: RoomState { + events: vec![event.to_sync_state_event()], + }, + }, + ); + return Ok(()); + } + + let mut left_state_events = Vec::new(); + + let since_shortstatehash = services + .rooms + .user + .get_token_shortstatehash(room_id, since) + .await; + + let since_state_ids = match since_shortstatehash { + Ok(s) => services.rooms.state_accessor.state_full_ids(s).await?, + Err(_) => HashMap::new(), + }; + + let Ok(left_event_id) = services + .rooms + .state_accessor + .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str()) + .await + else { + error!("Left room but no left state event"); + return Ok(()); + }; + + 
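+    // Resolve the state snapshot as of the leave event, so the response can
+    // include the room state the user was last entitled to see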
let Ok(left_shortstatehash) = services + .rooms + .state_accessor + .pdu_shortstatehash(&left_event_id) + .await + else { + error!(event_id = %left_event_id, "Leave event has no state"); + return Ok(()); + }; + + let mut left_state_ids = services + .rooms + .state_accessor + .state_full_ids(left_shortstatehash) + .await?; + + let leave_shortstatekey = services + .rooms + .short + .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str()) + .await; + + left_state_ids.insert(leave_shortstatekey, left_event_id); + + for (shortstatekey, event_id) in left_state_ids { + if full_state || since_state_ids.get(&shortstatekey) != Some(&event_id) { + let (event_type, state_key) = services + .rooms + .short + .get_statekey_from_short(shortstatekey) + .await?; + + // TODO: Delete "element_hacks" when this is resolved: https://github.com/vector-im/element-web/issues/22565 + if !lazy_load_enabled + || event_type != StateEventType::RoomMember + || full_state + || (cfg!(feature = "element_hacks") && *sender_user == state_key) + { + let Ok(pdu) = services.rooms.timeline.get_pdu(&event_id).await else { + error!("Pdu in state not found: {event_id}"); + continue; + }; + + left_state_events.push(pdu.to_sync_state_event()); + } + } + } + + left_rooms.insert( + room_id.to_owned(), + LeftRoom { + account_data: RoomAccountData { + events: Vec::new(), + }, + timeline: Timeline { + limited: false, + prev_batch: Some(next_batch_string.to_owned()), + events: Vec::new(), + }, + state: RoomState { + events: left_state_events, + }, + }, + ); + Ok(()) +} + +async fn process_presence_updates( + services: &Services, presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, since: u64, syncing_user: &UserId, +) -> Result<()> { + let presence_since = services.presence.presence_since(since); + + // Take presence updates + pin_mut!(presence_since); + while let Some((user_id, _, presence_bytes)) = presence_since.next().await { + if !services + .rooms + .state_cache + .user_sees_user(syncing_user, user_id) + .await + { + continue; + } + + let presence_event = services + .presence + .from_json_bytes_to_event(presence_bytes, user_id) + .await?; + + match presence_updates.entry(user_id.into()) { + Entry::Vacant(slot) => { + slot.insert(presence_event); + }, + Entry::Occupied(mut slot) => { + let curr_event = slot.get_mut(); + let curr_content = &mut curr_event.content; + let new_content = presence_event.content; + + // Update existing presence event with more info + curr_content.presence = new_content.presence; + curr_content.status_msg = new_content + .status_msg + .or_else(|| curr_content.status_msg.take()); + curr_content.last_active_ago = new_content.last_active_ago.or(curr_content.last_active_ago); + curr_content.displayname = new_content + .displayname + .or_else(|| curr_content.displayname.take()); + curr_content.avatar_url = new_content + .avatar_url + .or_else(|| curr_content.avatar_url.take()); + curr_content.currently_active = new_content + .currently_active + .or(curr_content.currently_active); + }, + }; + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +async fn load_joined_room( + services: &Services, sender_user: &UserId, sender_device: &DeviceId, room_id: &RoomId, since: u64, + sincecount: PduCount, next_batch: u64, next_batchcount: PduCount, lazy_load_enabled: bool, + lazy_load_send_redundant: bool, full_state: bool, device_list_updates: &mut HashSet<OwnedUserId>, + left_encrypted_users: &mut HashSet<OwnedUserId>, +) -> Result<JoinedRoom> { + // Get and drop the lock to wait for 
remaining operations to finish
+    // This makes sure that we have all events up to next_batch
+    let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await;
+    drop(insert_lock);
+
+    let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10_usize).await?;
+
+    let send_notification_counts = !timeline_pdus.is_empty()
+        || services
+            .rooms
+            .user
+            .last_notification_read(sender_user, room_id)
+            .await > since;
+
+    let mut timeline_users = HashSet::new();
+    for (_, event) in &timeline_pdus {
+        timeline_users.insert(event.sender.as_str().to_owned());
+    }
+
+    services
+        .rooms
+        .lazy_loading
+        .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount);
+
+    // Database queries:
+
+    let current_shortstatehash = services
+        .rooms
+        .state
+        .get_room_shortstatehash(room_id)
+        .await
+        .map_err(|_| err!(Database(error!("Room {room_id} has no state"))))?;
+
+    let since_shortstatehash = services
+        .rooms
+        .user
+        .get_token_shortstatehash(room_id, since)
+        .await
+        .ok();
+
+    let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = if timeline_pdus
+        .is_empty()
+        && (since_shortstatehash.is_none() || since_shortstatehash.is_some_and(is_equal_to!(current_shortstatehash)))
+    {
+        // No state changes
+        (Vec::new(), None, None, false, Vec::new())
+    } else {
+        // Calculates joined_member_count, invited_member_count and heroes
+        let calculate_counts = || async {
+            let joined_member_count = services
+                .rooms
+                .state_cache
+                .room_joined_count(room_id)
+                .await
+                .unwrap_or(0);
+
+            let invited_member_count = services
+                .rooms
+                .state_cache
+                .room_invited_count(room_id)
+                .await
+                .unwrap_or(0);
+
+            if joined_member_count.saturating_add(invited_member_count) > 5 {
+                return Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), Vec::new()));
+            }
+
+            // Go through all PDUs and for each member event, check if the user is still
+            // joined or invited until we have 5 or we reach the end
+
+            // Recalculate heroes (first 5 members)
+            let heroes = services
+                .rooms
+                .timeline
+                .all_pdus(sender_user, room_id)
+                .await?
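+                // keep only m.room.member PDUs; collecting into a HashSet below
+                // dedupes users that have multiple member events in the timeline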
+                .ready_filter(|(_, pdu)| pdu.kind == RoomMember)
+                .filter_map(|(_, pdu)| async move {
+                    let content: RoomMemberEventContent = pdu.get_content().ok()?;
+                    let user_id: &UserId = pdu.state_key.as_deref().map(TryInto::try_into).flat_ok()?;
+
+                    if user_id == sender_user {
+                        return None;
+                    }
+
+                    // The membership was and still is invite or join
+                    if !matches!(content.membership, MembershipState::Join | MembershipState::Invite) {
+                        return None;
+                    }
+
+                    let is_invited = services.rooms.state_cache.is_invited(user_id, room_id);
+
+                    let is_joined = services.rooms.state_cache.is_joined(user_id, room_id);
+
+                    // Drop users who are neither joined nor invited anymore
+                    if !is_joined.await && !is_invited.await {
+                        return None;
+                    }
+
+                    Some(user_id.to_owned())
+                })
+                .collect::<HashSet<OwnedUserId>>()
+                .await;
+
+            Ok::<_, Error>((
+                Some(joined_member_count),
+                Some(invited_member_count),
+                heroes.into_iter().collect::<Vec<_>>(),
+            ))
+        };
+
+        let get_sender_member_content = |short| {
+            services
+                .rooms
+                .state_accessor
+                .state_get_content(short, &StateEventType::RoomMember, sender_user.as_str())
+                .ok()
+        };
+
+        let since_sender_member: OptionFuture<_> = since_shortstatehash.map(get_sender_member_content).into();
+
+        let joined_since_last_sync = since_sender_member
+            .await
+            .flatten()
+            .map_or(true, |content: RoomMemberEventContent| {
+                content.membership != MembershipState::Join
+            });
+
+        if since_shortstatehash.is_none() || joined_since_last_sync {
+            // Probably since = 0, we will do an initial sync
+
+            let (joined_member_count, invited_member_count, heroes) = calculate_counts().await?;
+
+            let current_state_ids = services
+                .rooms
+                .state_accessor
+                .state_full_ids(current_shortstatehash)
+                .await?;
+
+            let mut state_events = Vec::new();
+            let mut lazy_loaded = HashSet::new();
+
+            for (shortstatekey, event_id) in current_state_ids {
+                let (event_type, state_key) = services
+                    .rooms
+                    .short
+                    .get_statekey_from_short(shortstatekey)
+                    .await?;
+
+                if event_type != StateEventType::RoomMember {
+                    let Ok(pdu) = services.rooms.timeline.get_pdu(&event_id).await else {
+                        error!("Pdu in state not found: {event_id}");
+                        continue;
+                    };
+
+                    state_events.push(pdu);
+                    continue;
+                }
+
+                // TODO: Delete "element_hacks" when this is resolved: https://github.com/vector-im/element-web/issues/22565
+                if !lazy_load_enabled
+                    || full_state || timeline_users.contains(&state_key)
+                    || (cfg!(feature = "element_hacks") && *sender_user == state_key)
+                {
+                    let Ok(pdu) = services.rooms.timeline.get_pdu(&event_id).await else {
+                        error!("Pdu in state not found: {event_id}");
+                        continue;
+                    };
+
+                    // This check is in case a bad user ID made it into the database
+                    if let Ok(uid) = UserId::parse(&state_key) {
+                        lazy_loaded.insert(uid);
+                    }
+
+                    state_events.push(pdu);
+                }
+            }
+
+            // Reset lazy loading because this is an initial sync
+            services
+                .rooms
+                .lazy_loading
+                .lazy_load_reset(sender_user, sender_device, room_id)
+                .await;
+
+            // The state_events above should contain all timeline_users, let's mark them as
+            // lazy loaded.
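+            // (lazy_load_mark_sent records them per device, so later incremental
+            // syncs can avoid re-sending these member events)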
+ services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ); + + (heroes, joined_member_count, invited_member_count, true, state_events) + } else { + // Incremental /sync + let since_shortstatehash = since_shortstatehash.expect("missing since_shortstatehash on incremental sync"); + + let mut delta_state_events = Vec::new(); + + if since_shortstatehash != current_shortstatehash { + let current_state_ids = services + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; + + let since_state_ids = services + .rooms + .state_accessor + .state_full_ids(since_shortstatehash) + .await?; + + for (key, id) in current_state_ids { + if full_state || since_state_ids.get(&key) != Some(&id) { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; + + delta_state_events.push(pdu); + tokio::task::yield_now().await; + } + } + } + + let encrypted_room = services + .rooms + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") + .await + .is_ok(); + + let since_encryption = services + .rooms + .state_accessor + .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") + .await; + + // Calculations: + let new_encrypted_room = encrypted_room && since_encryption.is_err(); + + let send_member_count = delta_state_events + .iter() + .any(|event| event.kind == RoomMember); + + if encrypted_room { + for state_event in &delta_state_events { + if state_event.kind != RoomMember { + continue; + } + + if let Some(state_key) = &state_event.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; + + if user_id == sender_user { + continue; + } + + let content: RoomMemberEventContent = state_event.get_content()?; + + match content.membership { + MembershipState::Join => { + // A new user joined an encrypted room + if !share_encrypted_room(services, sender_user, &user_id, Some(room_id)).await { + device_list_updates.insert(user_id); + } + }, + MembershipState::Leave => { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert(user_id); + }, + _ => {}, + } + } + } + } + + if joined_since_last_sync && encrypted_room || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined users + device_list_updates.extend( + services + .rooms + .state_cache + .room_members(room_id) + // Don't send key updates from the sender to the sender + .ready_filter(|user_id| sender_user != *user_id) + // Only send keys if the sender doesn't share an encrypted room with the target + // already + .filter_map(|user_id| { + share_encrypted_room(services, sender_user, user_id, Some(room_id)) + .map(|res| res.or_some(user_id.to_owned())) + }) + .collect::<Vec<_>>() + .await, + ); + } + + let (joined_member_count, invited_member_count, heroes) = if send_member_count { + calculate_counts().await? 
+ } else { + (None, None, Vec::new()) + }; + + let mut state_events = delta_state_events; + let mut lazy_loaded = HashSet::new(); + + // Mark all member events we're returning as lazy-loaded + for pdu in &state_events { + if pdu.kind == RoomMember { + match UserId::parse( + pdu.state_key + .as_ref() + .expect("State event has state key") + .clone(), + ) { + Ok(state_key_userid) => { + lazy_loaded.insert(state_key_userid); + }, + Err(e) => error!("Invalid state key for member event: {}", e), + } + } + } + + // Fetch contextual member state events for events from the timeline, and + // mark them as lazy-loaded as well. + for (_, event) in &timeline_pdus { + if lazy_loaded.contains(&event.sender) { + continue; + } + + if !services + .rooms + .lazy_loading + .lazy_load_was_sent_before(sender_user, sender_device, room_id, &event.sender) + .await || lazy_load_send_redundant + { + if let Ok(member_event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, event.sender.as_str()) + .await + { + lazy_loaded.insert(event.sender.clone()); + state_events.push(member_event); + } + } + } + + services.rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ); + + ( + heroes, + joined_member_count, + invited_member_count, + joined_since_last_sync, + state_events, + ) + } + }; + + // Look for device list updates in this room + device_list_updates.extend( + services + .users + .room_keys_changed(room_id, since, None) + .map(|(user_id, _)| user_id) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await, + ); + + let notification_count = if send_notification_counts { + Some( + services + .rooms + .user + .notification_count(sender_user, room_id) + .await + .try_into() + .expect("notification count can't go that high"), + ) + } else { + None + }; + + let highlight_count = if send_notification_counts { + Some( + services + .rooms + .user + .highlight_count(sender_user, room_id) + .await + .try_into() + .expect("highlight count can't go that high"), + ) + } else { + None + }; + + let prev_batch = timeline_pdus + .first() + .map(at!(0)) + .map(|count| count.saturating_sub(1)) + .as_ref() + .map(ToString::to_string); + + let room_events: Vec<_> = timeline_pdus + .iter() + .stream() + .filter_map(|(_, pdu)| async move { + // list of safe and common non-state events to ignore + if matches!( + &pdu.kind, + RoomMessage + | Sticker | CallInvite + | CallNotify | RoomEncrypted + | Image | File | Audio + | Voice | Video | UnstablePollStart + | PollStart | KeyVerificationStart + | Reaction | Emote + | Location + ) && services + .users + .user_is_ignored(&pdu.sender, sender_user) + .await + { + return None; + } + + Some(pdu.to_sync_room_event()) + }) + .collect() + .await; + + let edus: HashMap<OwnedUserId, Raw<AnySyncEphemeralRoomEvent>> = services + .rooms + .read_receipt + .readreceipts_since(room_id, since) + .filter_map(|(read_user, _, edu)| async move { + services + .users + .user_is_ignored(&read_user, sender_user) + .await + .or_some((read_user, edu)) + }) + .collect() + .await; + + let mut edus: Vec<Raw<AnySyncEphemeralRoomEvent>> = edus.into_values().collect(); + + if services.rooms.typing.last_typing_update(room_id).await? 
> since { + edus.push( + serde_json::from_str( + &serde_json::to_string( + &services + .rooms + .typing + .typings_all(room_id, sender_user) + .await?, + ) + .expect("event is valid, we just created it"), + ) + .expect("event is valid, we just created it"), + ); + } + + // Save the state after this sync so we can send the correct state diff next + // sync + services + .rooms + .user + .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash) + .await; + + Ok(JoinedRoom { + account_data: RoomAccountData { + events: services + .account_data + .changes_since(Some(room_id), sender_user, since) + .await? + .into_iter() + .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) + .collect(), + }, + summary: RoomSummary { + heroes, + joined_member_count: joined_member_count.map(ruma_from_u64), + invited_member_count: invited_member_count.map(ruma_from_u64), + }, + unread_notifications: UnreadNotificationsCount { + highlight_count, + notification_count, + }, + timeline: Timeline { + limited: limited || joined_since_last_sync, + prev_batch, + events: room_events, + }, + state: RoomState { + events: state_events + .iter() + .map(|pdu| pdu.to_sync_state_event()) + .collect(), + }, + ephemeral: Ephemeral { + events: edus, + }, + unread_thread_notifications: BTreeMap::new(), + }) +} diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs new file mode 100644 index 0000000000000000000000000000000000000000..91abd24e996bbed97a6f42c858f4e3dc733dac9c --- /dev/null +++ b/src/api/client/sync/v4.rs @@ -0,0 +1,793 @@ +use std::{ + cmp::{self, Ordering}, + collections::{BTreeMap, BTreeSet, HashSet}, + time::Duration, +}; + +use axum::extract::State; +use conduit::{ + debug, error, extract_variant, + utils::{ + math::{ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, + BoolExt, IterStream, ReadyExt, TryFutureExtExt, + }, + warn, Error, PduCount, Result, +}; +use futures::{FutureExt, StreamExt, TryFutureExt}; +use ruma::{ + api::client::{ + error::ErrorKind, + sync::sync_events::{ + self, + v4::{SlidingOp, SlidingSyncRoomHero}, + DeviceLists, UnreadNotificationsCount, + }, + }, + directory::RoomTypeFilter, + events::{ + room::member::{MembershipState, RoomMemberEventContent}, + AnyRawAccountDataEvent, StateEventType, + TimelineEventType::{self, *}, + }, + state_res::Event, + uint, MilliSecondsSinceUnixEpoch, OwnedRoomId, UInt, UserId, +}; +use service::{rooms::read_receipt::pack_receipts, Services}; + +use super::{load_timeline, share_encrypted_room}; +use crate::Ruma; + +const SINGLE_CONNECTION_SYNC: &str = "single_connection_sync"; +const DEFAULT_BUMP_TYPES: &[TimelineEventType; 6] = + &[RoomMessage, RoomEncrypted, Sticker, CallInvite, PollStart, Beacon]; + +/// POST `/_matrix/client/unstable/org.matrix.msc3575/sync` +/// +/// Sliding Sync endpoint (future endpoint: `/_matrix/client/v4/sync`) +pub(crate) async fn sync_events_v4_route( + State(services): State<crate::State>, body: Ruma<sync_events::v4::Request>, +) -> Result<sync_events::v4::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.expect("user is authenticated"); + let mut body = body.body; + // Setup watchers, so if there's no response, we can wait for them + let watcher = services.sync.watch(sender_user, &sender_device); + + let next_batch = services.globals.next_count()?; + + let conn_id = body + .conn_id + .clone() + .unwrap_or_else(|| SINGLE_CONNECTION_SYNC.to_owned()); + + let globalsince = body + .pos + .as_ref() + 
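+        // `pos` is the sliding sync position token from the previous response;
+        // a missing or unparsable token is treated as an initial sync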
.and_then(|string| string.parse().ok()) + .unwrap_or(0); + + if globalsince != 0 + && !services + .sync + .remembered(sender_user.clone(), sender_device.clone(), conn_id.clone()) + { + debug!("Restarting sync stream because it was gone from the database"); + return Err(Error::Request( + ErrorKind::UnknownPos, + "Connection data lost since last time".into(), + http::StatusCode::BAD_REQUEST, + )); + } + + if globalsince == 0 { + services + .sync + .forget_sync_request_connection(sender_user.clone(), sender_device.clone(), conn_id.clone()); + } + + // Get sticky parameters from cache + let known_rooms = + services + .sync + .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); + + let all_joined_rooms: Vec<_> = services + .rooms + .state_cache + .rooms_joined(sender_user) + .map(ToOwned::to_owned) + .collect() + .await; + + let all_invited_rooms: Vec<_> = services + .rooms + .state_cache + .rooms_invited(sender_user) + .map(|r| r.0) + .collect() + .await; + + let all_rooms = all_joined_rooms + .iter() + .chain(all_invited_rooms.iter()) + .map(Clone::clone) + .collect(); + + if body.extensions.to_device.enabled.unwrap_or(false) { + services + .users + .remove_to_device_events(sender_user, &sender_device, globalsince) + .await; + } + + let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in + let mut device_list_changes = HashSet::new(); + let mut device_list_left = HashSet::new(); + + let mut receipts = sync_events::v4::Receipts { + rooms: BTreeMap::new(), + }; + + let mut account_data = sync_events::v4::AccountData { + global: Vec::new(), + rooms: BTreeMap::new(), + }; + if body.extensions.account_data.enabled.unwrap_or(false) { + account_data.global = services + .account_data + .changes_since(None, sender_user, globalsince) + .await? + .into_iter() + .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) + .collect(); + + if let Some(rooms) = body.extensions.account_data.rooms { + for room in rooms { + account_data.rooms.insert( + room.clone(), + services + .account_data + .changes_since(Some(&room), sender_user, globalsince) + .await? 
+ .into_iter() + .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) + .collect(), + ); + } + } + } + + if body.extensions.e2ee.enabled.unwrap_or(false) { + // Look for device list updates of this account + device_list_changes.extend( + services + .users + .keys_changed(sender_user, globalsince, None) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await, + ); + + for room_id in &all_joined_rooms { + let Ok(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id).await else { + error!("Room {room_id} has no state"); + continue; + }; + + let since_shortstatehash = services + .rooms + .user + .get_token_shortstatehash(room_id, globalsince) + .await + .ok(); + + let encrypted_room = services + .rooms + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") + .await + .is_ok(); + + if let Some(since_shortstatehash) = since_shortstatehash { + // Skip if there are only timeline changes + if since_shortstatehash == current_shortstatehash { + continue; + } + + let since_encryption = services + .rooms + .state_accessor + .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") + .await; + + let since_sender_member: Option<RoomMemberEventContent> = services + .rooms + .state_accessor + .state_get_content(since_shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) + .ok() + .await; + + let joined_since_last_sync = + since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + + let new_encrypted_room = encrypted_room && since_encryption.is_err(); + + if encrypted_room { + let current_state_ids = services + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; + + let since_state_ids = services + .rooms + .state_accessor + .state_full_ids(since_shortstatehash) + .await?; + + for (key, id) in current_state_ids { + if since_state_ids.get(&key) != Some(&id) { + let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { + error!("Pdu in state not found: {id}"); + continue; + }; + if pdu.kind == RoomMember { + if let Some(state_key) = &pdu.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; + + if user_id == *sender_user { + continue; + } + + let content: RoomMemberEventContent = pdu.get_content()?; + match content.membership { + MembershipState::Join => { + // A new user joined an encrypted room + if !share_encrypted_room(&services, sender_user, &user_id, Some(room_id)) + .await + { + device_list_changes.insert(user_id); + } + }, + MembershipState::Leave => { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert(user_id); + }, + _ => {}, + } + } + } + } + } + if joined_since_last_sync || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined users + device_list_changes.extend( + services + .rooms + .state_cache + .room_members(room_id) + // Don't send key updates from the sender to the sender + .ready_filter(|user_id| sender_user != user_id) + // Only send keys if the sender doesn't share an encrypted room with the target + // already + .filter_map(|user_id| { + share_encrypted_room(&services, sender_user, user_id, Some(room_id)) + .map(|res| res.or_some(user_id.to_owned())) + }) + .collect::<Vec<_>>() + .await, + ); + } + } + } + // Look for device list updates in this room + device_list_changes.extend( + services + .users + .room_keys_changed(room_id, globalsince, None) + .map(|(user_id, _)| 
user_id) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await, + ); + } + + for user_id in left_encrypted_users { + let dont_share_encrypted_room = !share_encrypted_room(&services, sender_user, &user_id, None).await; + + // If the user doesn't share an encrypted room with the target anymore, we need + // to tell them + if dont_share_encrypted_room { + device_list_left.insert(user_id); + } + } + } + + let mut lists = BTreeMap::new(); + let mut todo_rooms = BTreeMap::new(); // and required state + + for (list_id, list) in &body.lists { + let active_rooms = match list.filters.clone().and_then(|f| f.is_invite) { + Some(true) => &all_invited_rooms, + Some(false) => &all_joined_rooms, + None => &all_rooms, + }; + + let active_rooms = match list.filters.clone().map(|f| f.not_room_types) { + Some(filter) if filter.is_empty() => active_rooms.clone(), + Some(value) => filter_rooms(&services, active_rooms, &value, true).await, + None => active_rooms.clone(), + }; + + let active_rooms = match list.filters.clone().map(|f| f.room_types) { + Some(filter) if filter.is_empty() => active_rooms.clone(), + Some(value) => filter_rooms(&services, &active_rooms, &value, false).await, + None => active_rooms, + }; + + let mut new_known_rooms = BTreeSet::new(); + + let ranges = list.ranges.clone(); + lists.insert( + list_id.clone(), + sync_events::v4::SyncList { + ops: ranges + .into_iter() + .map(|mut r| { + r.0 = r.0.clamp( + uint!(0), + UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX), + ); + r.1 = + r.1.clamp(r.0, UInt::try_from(active_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX)); + + let room_ids = if !active_rooms.is_empty() { + active_rooms[usize_from_ruma(r.0)..=usize_from_ruma(r.1)].to_vec() + } else { + Vec::new() + }; + + new_known_rooms.extend(room_ids.iter().cloned()); + for room_id in &room_ids { + let todo_room = + todo_rooms + .entry(room_id.clone()) + .or_insert((BTreeSet::new(), 0_usize, u64::MAX)); + + let limit: usize = list + .room_details + .timeline_limit + .map(u64::from) + .map_or(10, usize_from_u64_truncated) + .min(100); + + todo_room + .0 + .extend(list.room_details.required_state.iter().cloned()); + + todo_room.1 = todo_room.1.max(limit); + // 0 means unknown because it got out of date + todo_room.2 = todo_room.2.min( + known_rooms + .get(list_id.as_str()) + .and_then(|k| k.get(room_id)) + .copied() + .unwrap_or(0), + ); + } + sync_events::v4::SyncOp { + op: SlidingOp::Sync, + range: Some(r), + index: None, + room_ids, + room_id: None, + } + }) + .collect(), + count: ruma_from_usize(active_rooms.len()), + }, + ); + + if let Some(conn_id) = &body.conn_id { + services.sync.update_sync_known_rooms( + sender_user.clone(), + sender_device.clone(), + conn_id.clone(), + list_id.clone(), + new_known_rooms, + globalsince, + ); + } + } + + let mut known_subscription_rooms = BTreeSet::new(); + for (room_id, room) in &body.room_subscriptions { + if !services.rooms.metadata.exists(room_id).await { + continue; + } + let todo_room = todo_rooms + .entry(room_id.clone()) + .or_insert((BTreeSet::new(), 0_usize, u64::MAX)); + + let limit: usize = room + .timeline_limit + .map(u64::from) + .map_or(10, usize_from_u64_truncated) + .min(100); + + todo_room.0.extend(room.required_state.iter().cloned()); + todo_room.1 = todo_room.1.max(limit); + // 0 means unknown because it got out of date + todo_room.2 = todo_room.2.min( + known_rooms + .get("subscriptions") + .and_then(|k| k.get(room_id)) + .copied() + .unwrap_or(0), + ); + 
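+        // remember explicit subscriptions so they persist in the cached
+        // connection state and can be removed again via unsubscribe_rooms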
known_subscription_rooms.insert(room_id.clone()); + } + + for r in body.unsubscribe_rooms { + known_subscription_rooms.remove(&r); + body.room_subscriptions.remove(&r); + } + + if let Some(conn_id) = &body.conn_id { + services.sync.update_sync_known_rooms( + sender_user.clone(), + sender_device.clone(), + conn_id.clone(), + "subscriptions".to_owned(), + known_subscription_rooms, + globalsince, + ); + } + + if let Some(conn_id) = &body.conn_id { + services.sync.update_sync_subscriptions( + sender_user.clone(), + sender_device.clone(), + conn_id.clone(), + body.room_subscriptions, + ); + } + + let mut rooms = BTreeMap::new(); + for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms { + let roomsincecount = PduCount::Normal(*roomsince); + + let mut timestamp: Option<_> = None; + let mut invite_state = None; + let (timeline_pdus, limited); + if all_invited_rooms.contains(room_id) { + // TODO: figure out a timestamp we can use for remote invites + invite_state = services + .rooms + .state_cache + .invite_state(sender_user, room_id) + .await + .ok(); + + (timeline_pdus, limited) = (Vec::new(), true); + } else { + (timeline_pdus, limited) = + match load_timeline(&services, sender_user, room_id, roomsincecount, *timeline_limit).await { + Ok(value) => value, + Err(err) => { + warn!("Encountered missing timeline in {}, error {}", room_id, err); + continue; + }, + }; + } + + account_data.rooms.insert( + room_id.clone(), + services + .account_data + .changes_since(Some(room_id), sender_user, *roomsince) + .await? + .into_iter() + .filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) + .collect(), + ); + + let vector: Vec<_> = services + .rooms + .read_receipt + .readreceipts_since(room_id, *roomsince) + .filter_map(|(read_user, ts, v)| async move { + (!services + .users + .user_is_ignored(&read_user, sender_user) + .await) + .then_some((read_user, ts, v)) + }) + .collect() + .await; + + let receipt_size = vector.len(); + receipts + .rooms + .insert(room_id.clone(), pack_receipts(Box::new(vector.into_iter()))); + + if roomsince != &0 + && timeline_pdus.is_empty() + && account_data.rooms.get(room_id).is_some_and(Vec::is_empty) + && receipt_size == 0 + { + continue; + } + + let prev_batch = timeline_pdus + .first() + .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { + Ok(Some(match pdu_count { + PduCount::Backfilled(_) => { + error!("timeline in backfill state?!"); + "0".to_owned() + }, + PduCount::Normal(c) => c.to_string(), + })) + })? 
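+            // if the window had no timeline events, fall back to the room's
+            // since token so clients can still paginate backwards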
+ .or_else(|| { + if roomsince != &0 { + Some(roomsince.to_string()) + } else { + None + } + }); + + let room_events: Vec<_> = timeline_pdus + .iter() + .stream() + .filter_map(|(_, pdu)| async move { + // list of safe and common non-state events to ignore + if matches!( + &pdu.kind, + RoomMessage + | Sticker | CallInvite + | CallNotify | RoomEncrypted + | Image | File | Audio + | Voice | Video | UnstablePollStart + | PollStart | KeyVerificationStart + | Reaction | Emote | Location + ) && services + .users + .user_is_ignored(&pdu.sender, sender_user) + .await + { + return None; + } + + Some(pdu.to_sync_room_event()) + }) + .collect() + .await; + + for (_, pdu) in timeline_pdus { + let ts = MilliSecondsSinceUnixEpoch(pdu.origin_server_ts); + if DEFAULT_BUMP_TYPES.contains(pdu.event_type()) && timestamp.is_none_or(|time| time <= ts) { + timestamp = Some(ts); + } + } + + let required_state = required_state_request + .iter() + .stream() + .filter_map(|state| async move { + services + .rooms + .state_accessor + .room_state_get(room_id, &state.0, &state.1) + .await + .map(|s| s.to_sync_state_event()) + .ok() + }) + .collect() + .await; + + // Heroes + let heroes: Vec<_> = services + .rooms + .state_cache + .room_members(room_id) + .ready_filter(|member| member != sender_user) + .filter_map(|user_id| { + services + .rooms + .state_accessor + .get_member(room_id, user_id) + .map_ok(|memberevent| SlidingSyncRoomHero { + user_id: user_id.into(), + name: memberevent.displayname, + avatar: memberevent.avatar_url, + }) + .ok() + }) + .take(5) + .collect() + .await; + + let name = match heroes.len().cmp(&(1_usize)) { + Ordering::Greater => { + let firsts = heroes[1..] + .iter() + .map(|h| h.name.clone().unwrap_or_else(|| h.user_id.to_string())) + .collect::<Vec<_>>() + .join(", "); + + let last = heroes[0] + .name + .clone() + .unwrap_or_else(|| heroes[0].user_id.to_string()); + + Some(format!("{firsts} and {last}")) + }, + Ordering::Equal => Some( + heroes[0] + .name + .clone() + .unwrap_or_else(|| heroes[0].user_id.to_string()), + ), + Ordering::Less => None, + }; + + let heroes_avatar = if heroes.len() == 1 { + heroes[0].avatar.clone() + } else { + None + }; + + rooms.insert( + room_id.clone(), + sync_events::v4::SlidingSyncRoom { + name: services + .rooms + .state_accessor + .get_name(room_id) + .await + .ok() + .or(name), + avatar: if let Some(heroes_avatar) = heroes_avatar { + ruma::JsOption::Some(heroes_avatar) + } else { + match services.rooms.state_accessor.get_avatar(room_id).await { + ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url), + ruma::JsOption::Null => ruma::JsOption::Null, + ruma::JsOption::Undefined => ruma::JsOption::Undefined, + } + }, + initial: Some(roomsince == &0), + is_dm: None, + invite_state, + unread_notifications: UnreadNotificationsCount { + highlight_count: Some( + services + .rooms + .user + .highlight_count(sender_user, room_id) + .await + .try_into() + .expect("notification count can't go that high"), + ), + notification_count: Some( + services + .rooms + .user + .notification_count(sender_user, room_id) + .await + .try_into() + .expect("notification count can't go that high"), + ), + }, + timeline: room_events, + required_state, + prev_batch, + limited, + joined_count: Some( + services + .rooms + .state_cache + .room_joined_count(room_id) + .await + .unwrap_or(0) + .try_into() + .unwrap_or_else(|_| uint!(0)), + ), + invited_count: Some( + services + .rooms + .state_cache + .room_invited_count(room_id) + .await + .unwrap_or(0) + .try_into() + 
.unwrap_or_else(|_| uint!(0)), + ), + num_live: None, // Count events in timeline greater than global sync counter + timestamp, + heroes: Some(heroes), + }, + ); + } + + if rooms + .iter() + .all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty()) + { + // Hang a few seconds so requests are not spammed + // Stop hanging if new info arrives + let default = Duration::from_secs(30); + let duration = cmp::min(body.timeout.unwrap_or(default), default); + _ = tokio::time::timeout(duration, watcher).await; + } + + Ok(sync_events::v4::Response { + initial: globalsince == 0, + txn_id: body.txn_id.clone(), + pos: next_batch.to_string(), + lists, + rooms, + extensions: sync_events::v4::Extensions { + to_device: if body.extensions.to_device.enabled.unwrap_or(false) { + Some(sync_events::v4::ToDevice { + events: services + .users + .get_to_device_events(sender_user, &sender_device) + .collect() + .await, + next_batch: next_batch.to_string(), + }) + } else { + None + }, + e2ee: sync_events::v4::E2EE { + device_lists: DeviceLists { + changed: device_list_changes.into_iter().collect(), + left: device_list_left.into_iter().collect(), + }, + device_one_time_keys_count: services + .users + .count_one_time_keys(sender_user, &sender_device) + .await, + // Fallback keys are not yet supported + device_unused_fallback_key_types: None, + }, + account_data, + receipts, + typing: sync_events::v4::Typing { + rooms: BTreeMap::new(), + }, + }, + delta_token: None, + }) +} + +async fn filter_rooms( + services: &Services, rooms: &[OwnedRoomId], filter: &[RoomTypeFilter], negate: bool, +) -> Vec<OwnedRoomId> { + rooms + .iter() + .stream() + .filter_map(|r| async move { + let room_type = services.rooms.state_accessor.get_room_type(r).await; + + if room_type.as_ref().is_err_and(|e| !e.is_not_found()) { + return None; + } + + let room_type_filter = RoomTypeFilter::from(room_type.ok()); + + let include = if negate { + !filter.contains(&room_type_filter) + } else { + filter.is_empty() || filter.contains(&room_type_filter) + }; + + include.then_some(r.to_owned()) + }) + .collect() + .await +} diff --git a/src/api/client/tag.rs b/src/api/client/tag.rs index 301568e50b0b79f9919f68d8df4084250ca615fd..b5fa19e3aa8327a11cc69b6073ceb2e27b62d710 100644 --- a/src/api/client/tag.rs +++ b/src/api/client/tag.rs @@ -9,7 +9,7 @@ }, }; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` /// @@ -21,32 +21,30 @@ pub(crate) async fn update_tag_route( ) -> Result<create_tag::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services + let mut tags_event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; - - let mut tags_event = event.map_or_else( - || { - Ok(TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - }, - |e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")), - )?; + .get_room(&body.room_id, sender_user, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }); tags_event .content .tags .insert(body.tag.clone().into(), body.tag_info.clone()); - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + 
RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(create_tag::v3::Response {}) } @@ -61,29 +59,27 @@ pub(crate) async fn delete_tag_route( ) -> Result<delete_tag::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services + let mut tags_event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; - - let mut tags_event = event.map_or_else( - || { - Ok(TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - }, - |e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")), - )?; + .get_room(&body.room_id, sender_user, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }); tags_event.content.tags.remove(&body.tag.clone().into()); - services.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + services + .account_data + .update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + ) + .await?; Ok(delete_tag::v3::Response {}) } @@ -98,20 +94,15 @@ pub(crate) async fn get_tags_route( ) -> Result<get_tags::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services + let tags_event = services .account_data - .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; - - let tags_event = event.map_or_else( - || { - Ok(TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - }, - |e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")), - )?; + .get_room(&body.room_id, sender_user, RoomAccountDataEventType::Tag) + .await + .unwrap_or(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }); Ok(get_tags::v3::Response { tags: tags_event.content.tags, diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 8100f0e67f1838c9c09a750232ccb4c2fac87310..906f779da95c8fb4720422de0e83892844a1be14 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,17 +1,14 @@ use axum::extract::State; -use ruma::{ - api::client::{error::ErrorKind, threads::get_threads}, - uint, -}; +use conduit::{at, PduCount, PduEvent}; +use futures::StreamExt; +use ruma::{api::client::threads::get_threads, uint}; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/threads` pub(crate) async fn get_threads_route( - State(services): State<crate::State>, body: Ruma<get_threads::v1::Request>, + State(services): State<crate::State>, ref body: Ruma<get_threads::v1::Request>, ) -> Result<get_threads::v1::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - // Use limit or else 10, with maximum 100 let limit = body .limit @@ -20,35 +17,42 @@ pub(crate) async fn get_threads_route( .unwrap_or(10) .min(100); - let from = if let Some(from) = &body.from { - from.parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))? - } else { - u64::MAX - }; + let from: PduCount = body + .from + .as_deref() + .map(str::parse) + .transpose()? 
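+        // no `from` token means we start from the most recent threads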
+ .unwrap_or_else(PduCount::max); - let threads = services + let threads: Vec<(PduCount, PduEvent)> = services .rooms .threads - .threads_until(sender_user, &body.room_id, from, &body.include)? + .threads_until(body.sender_user(), &body.room_id, from, &body.include) + .await? .take(limit) - .filter_map(Result::ok) - .filter(|(_, pdu)| { + .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) - .unwrap_or(false) + .user_can_see_event(body.sender_user(), &body.room_id, &pdu.event_id) + .await + .then_some((count, pdu)) }) - .collect::<Vec<_>>(); - - let next_batch = threads.last().map(|(count, _)| count.to_string()); + .collect() + .await; Ok(get_threads::v1::Response { + next_batch: threads + .last() + .filter(|_| threads.len() >= limit) + .map(at!(0)) + .as_ref() + .map(ToString::to_string), + chunk: threads .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) + .map(at!(1)) + .map(|pdu| pdu.to_room_event()) .collect(), - next_batch, }) } diff --git a/src/api/client/to_device.rs b/src/api/client/to_device.rs index 1f557ad7b85b37be43c06b7d1516f7e0b12b41ea..2b37a9ec5f47c7f9cec5a7fc9c22615748b9b0d7 100644 --- a/src/api/client/to_device.rs +++ b/src/api/client/to_device.rs @@ -2,6 +2,7 @@ use axum::extract::State; use conduit::{Error, Result}; +use futures::StreamExt; use ruma::{ api::{ client::{error::ErrorKind, to_device::send_event_to_device}, @@ -24,8 +25,9 @@ pub(crate) async fn send_event_to_device_route( // Check if this is a new transaction id if services .transaction_ids - .existing_txnid(sender_user, sender_device, &body.txn_id)? - .is_some() + .existing_txnid(sender_user, sender_device, &body.txn_id) + .await + .is_ok() { return Ok(send_event_to_device::v3::Response {}); } @@ -53,31 +55,35 @@ pub(crate) async fn send_event_to_device_route( continue; } + let event_type = &body.event_type.to_string(); + + let event = event + .deserialize_as() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?; + match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services.users.add_to_device_event( - sender_user, - target_user_id, - target_device_id, - &body.event_type.to_string(), - event - .deserialize_as() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, - )?; + services + .users + .add_to_device_event(sender_user, target_user_id, target_device_id, event_type, event) + .await; }, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services.users.all_device_ids(target_user_id) { - services.users.add_to_device_event( - sender_user, - target_user_id, - &target_device_id?, - &body.event_type.to_string(), - event - .deserialize_as() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, - )?; - } + let (event_type, event) = (&event_type, &event); + services + .users + .all_device_ids(target_user_id) + .for_each(|target_device_id| { + services.users.add_to_device_event( + sender_user, + target_user_id, + target_device_id, + event_type, + event.clone(), + ) + }) + .await; }, } } @@ -86,7 +92,7 @@ pub(crate) async fn send_event_to_device_route( // Save transaction id with empty data services .transaction_ids - .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; + .add_txnid(sender_user, sender_device, &body.txn_id, &[]); Ok(send_event_to_device::v3::Response {}) } diff --git a/src/api/client/typing.rs b/src/api/client/typing.rs index 
a06648e05af1591c311673554a8bd8c6dd24ed73..932d221edd31ce255db953d1250220da91925023 100644 --- a/src/api/client/typing.rs +++ b/src/api/client/typing.rs @@ -16,7 +16,8 @@ pub(crate) async fn create_typing_event_route( if !services .rooms .state_cache - .is_joined(sender_user, &body.room_id)? + .is_joined(sender_user, &body.room_id) + .await { return Err(Error::BadRequest(ErrorKind::forbidden(), "You are not in this room.")); } diff --git a/src/api/client/unstable.rs b/src/api/client/unstable.rs index ab4703fdbb576cda2b032c7a1735e028b5602bc0..dc570295c641838058e3f255b16ba29a06294220 100644 --- a/src/api/client/unstable.rs +++ b/src/api/client/unstable.rs @@ -2,7 +2,8 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{warn, Err}; +use conduit::Err; +use futures::StreamExt; use ruma::{ api::{ client::{ @@ -45,7 +46,7 @@ pub(crate) async fn get_mutual_rooms_route( )); } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { return Ok(mutual_rooms::unstable::Response { joined: vec![], next_batch_token: None, @@ -55,9 +56,10 @@ pub(crate) async fn get_mutual_rooms_route( let mutual_rooms: Vec<OwnedRoomId> = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? - .filter_map(Result::ok) - .collect(); + .get_shared_rooms(sender_user, &body.user_id) + .map(ToOwned::to_owned) + .collect() + .await; Ok(mutual_rooms::unstable::Response { joined: mutual_rooms, @@ -99,7 +101,7 @@ pub(crate) async fn get_room_summary( let room_id = services.rooms.alias.resolve(&body.room_id_or_alias).await?; - if !services.rooms.metadata.exists(&room_id)? { + if !services.rooms.metadata.exists(&room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); } @@ -108,7 +110,7 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .is_world_readable(&room_id) - .unwrap_or(false) + .await { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -122,50 +124,58 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .get_canonical_alias(&room_id) - .unwrap_or(None), + .await + .ok(), avatar_url: services .rooms .state_accessor - .get_avatar(&room_id)? 
+ .get_avatar(&room_id) + .await .into_option() .unwrap_or_default() .url, - guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id)?, - name: services - .rooms - .state_accessor - .get_name(&room_id) - .unwrap_or(None), + guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id).await, + name: services.rooms.state_accessor.get_name(&room_id).await.ok(), num_joined_members: services .rooms .state_cache .room_joined_count(&room_id) - .unwrap_or_default() - .unwrap_or_else(|| { - warn!("Room {room_id} has no member count"); - 0 - }) - .try_into() - .expect("user count should not be that big"), + .await + .unwrap_or(0) + .try_into()?, topic: services .rooms .state_accessor .get_room_topic(&room_id) - .unwrap_or(None), + .await + .ok(), world_readable: services .rooms .state_accessor .is_world_readable(&room_id) - .unwrap_or(false), - join_rule: services.rooms.state_accessor.get_join_rule(&room_id)?.0, - room_type: services.rooms.state_accessor.get_room_type(&room_id)?, - room_version: Some(services.rooms.state.get_room_version(&room_id)?), + .await, + join_rule: services + .rooms + .state_accessor + .get_join_rule(&room_id) + .await + .unwrap_or_default() + .0, + room_type: services + .rooms + .state_accessor + .get_room_type(&room_id) + .await + .ok(), + room_version: services.rooms.state.get_room_version(&room_id).await.ok(), membership: if let Some(sender_user) = sender_user { services .rooms .state_accessor - .get_member(&room_id, sender_user)? - .map_or_else(|| Some(MembershipState::Leave), |content| Some(content.membership)) + .get_member(&room_id, sender_user) + .await + .map_or_else(|_| MembershipState::Leave, |content| content.membership) + .into() } else { None }, @@ -173,7 +183,8 @@ pub(crate) async fn get_room_summary( .rooms .state_accessor .get_room_encryption(&room_id) - .unwrap_or_else(|_e| None), + .await + .ok(), }) } @@ -191,13 +202,14 @@ pub(crate) async fn delete_timezone_key_route( return Err!(Request(Forbidden("You cannot update the profile of another user"))); } - services.users.set_timezone(&body.user_id, None).await?; + services.users.set_timezone(&body.user_id, None); if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(delete_timezone_key::unstable::Response {}) @@ -217,16 +229,14 @@ pub(crate) async fn set_timezone_key_route( return Err!(Request(Forbidden("You cannot update the profile of another user"))); } - services - .users - .set_timezone(&body.user_id, body.tz.clone()) - .await?; + services.users.set_timezone(&body.user_id, body.tz.clone()); if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_timezone_key::unstable::Response {}) @@ -280,10 +290,11 @@ pub(crate) async fn set_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), all_joined_rooms).await?; + update_displayname(&services, &body.user_id, Some(profile_key_value.to_string()), &all_joined_rooms).await?; } else if body.key == "avatar_url" { let mxc = ruma::OwnedMxcUri::from(profile_key_value.to_string()); @@ -291,21 +302,23 @@ pub(crate) async fn set_profile_key_route( .rooms 
.state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_avatar_url(&services, &body.user_id, Some(mxc), None, all_joined_rooms).await?; + update_avatar_url(&services, &body.user_id, Some(mxc), None, &all_joined_rooms).await?; } else { services .users - .set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone()))?; + .set_profile_key(&body.user_id, &body.key, Some(profile_key_value.clone())); } if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(set_profile_key::unstable::Response {}) @@ -335,30 +348,33 @@ pub(crate) async fn delete_profile_key_route( .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_displayname(&services, &body.user_id, None, all_joined_rooms).await?; + update_displayname(&services, &body.user_id, None, &all_joined_rooms).await?; } else if body.key == "avatar_url" { let all_joined_rooms: Vec<OwnedRoomId> = services .rooms .state_cache .rooms_joined(&body.user_id) - .filter_map(Result::ok) - .collect(); + .map(Into::into) + .collect() + .await; - update_avatar_url(&services, &body.user_id, None, None, all_joined_rooms).await?; + update_avatar_url(&services, &body.user_id, None, None, &all_joined_rooms).await?; } else { services .users - .set_profile_key(&body.user_id, &body.key, None)?; + .set_profile_key(&body.user_id, &body.key, None); } if services.globals.allow_local_presence() { // Presence update services .presence - .ping_presence(&body.user_id, &PresenceState::Online)?; + .ping_presence(&body.user_id, &PresenceState::Online) + .await?; } Ok(delete_profile_key::unstable::Response {}) @@ -386,26 +402,25 @@ pub(crate) async fn get_timezone_key_route( ) .await { - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); return Ok(get_timezone_key::unstable::Response { tz: response.tz, @@ -413,14 +428,14 @@ pub(crate) async fn get_timezone_key_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_timezone_key::unstable::Response { - tz: services.users.timezone(&body.user_id)?, + tz: services.users.timezone(&body.user_id).await.ok(), }) } @@ -448,32 +463,31 @@ pub(crate) async fn get_profile_key_route( ) .await { - if !services.users.exists(&body.user_id)? 
{ + if !services.users.exists(&body.user_id).await { services.users.create(&body.user_id, None)?; } services .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; + .set_displayname(&body.user_id, response.displayname.clone()); + services .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; + .set_avatar_url(&body.user_id, response.avatar_url.clone()); + services .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + .set_blurhash(&body.user_id, response.blurhash.clone()); + services .users - .set_timezone(&body.user_id, response.tz.clone()) - .await?; + .set_timezone(&body.user_id, response.tz.clone()); if let Some(value) = response.custom_profile_fields.get(&body.key) { profile_key_value.insert(body.key.clone(), value.clone()); services .users - .set_profile_key(&body.user_id, &body.key, Some(value.clone()))?; + .set_profile_key(&body.user_id, &body.key, Some(value.clone())); } else { return Err!(Request(NotFound("The requested profile key does not exist."))); } @@ -484,13 +498,13 @@ pub(crate) async fn get_profile_key_route( } } - if !services.users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id).await { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation - return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); + return Err!(Request(NotFound("Profile was not found."))); } - if let Some(value) = services.users.profile_key(&body.user_id, &body.key)? { + if let Ok(value) = services.users.profile_key(&body.user_id, &body.key).await { profile_key_value.insert(body.key.clone(), value); } else { return Err!(Request(NotFound("The requested profile key does not exist."))); diff --git a/src/api/client/unversioned.rs b/src/api/client/unversioned.rs index d714fda5457f5f93944cb1040f60b9a7e9f41f32..3aee30c8bcd1c9baa2cc6979a7d1cc83016291c4 100644 --- a/src/api/client/unversioned.rs +++ b/src/api/client/unversioned.rs @@ -1,6 +1,7 @@ use std::collections::BTreeMap; use axum::{extract::State, response::IntoResponse, Json}; +use futures::StreamExt; use ruma::api::client::{ discovery::{ discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo}, @@ -52,6 +53,7 @@ pub(crate) async fn get_supported_versions_route( ("org.matrix.msc2946".to_owned(), true), /* spaces/hierarchy summaries (https://github.com/matrix-org/matrix-spec-proposals/pull/2946) */ ("org.matrix.msc3026.busy_presence".to_owned(), true), /* busy presence status (https://github.com/matrix-org/matrix-spec-proposals/pull/3026) */ ("org.matrix.msc3827".to_owned(), true), /* filtering of /publicRooms by room type (https://github.com/matrix-org/matrix-spec-proposals/pull/3827) */ + ("org.matrix.msc3952_intentional_mentions".to_owned(), true), /* intentional mentions (https://github.com/matrix-org/matrix-spec-proposals/pull/3952) */ ("org.matrix.msc3575".to_owned(), true), /* sliding sync (https://github.com/matrix-org/matrix-spec-proposals/pull/3575/files#r1588877046) */ ("org.matrix.msc3916.stable".to_owned(), true), /* authenticated media (https://github.com/matrix-org/matrix-spec-proposals/pull/3916) */ ("org.matrix.msc4180".to_owned(), true), /* stable flag for 3916 (https://github.com/matrix-org/matrix-spec-proposals/pull/4180) */ @@ -173,7 +175,7 @@ pub(crate) async fn conduwuit_server_version() -> Result<impl IntoResponse> { /// homeserver. Endpoint is disabled if federation is disabled for privacy. 
This /// only includes active users (not deactivated, no guests, etc) pub(crate) async fn conduwuit_local_user_count(State(services): State<crate::State>) -> Result<impl IntoResponse> { - let user_count = services.users.list_local_users()?.len(); + let user_count = services.users.list_local_users().count().await; Ok(Json(serde_json::json!({ "count": user_count diff --git a/src/api/client/user_directory.rs b/src/api/client/user_directory.rs index 87d4062cd5c5b86c55eeccd108330d79e1acae39..868811a3f46a2a6bad5008358e615f33f3e60551 100644 --- a/src/api/client/user_directory.rs +++ b/src/api/client/user_directory.rs @@ -1,4 +1,6 @@ use axum::extract::State; +use conduit::utils::TryFutureExtExt; +use futures::{pin_mut, StreamExt}; use ruma::{ api::client::user_directory::search_users, events::{ @@ -21,14 +23,12 @@ pub(crate) async fn search_users_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10 - let mut users = services.users.iter().filter_map(|user_id| { + let users = services.users.stream().filter_map(|user_id| async { // Filter out buggy users (they should not exist, but you never know...) - let user_id = user_id.ok()?; - let user = search_users::v3::User { - user_id: user_id.clone(), - display_name: services.users.displayname(&user_id).ok()?, - avatar_url: services.users.avatar_url(&user_id).ok()?, + user_id: user_id.to_owned(), + display_name: services.users.displayname(user_id).await.ok(), + avatar_url: services.users.avatar_url(user_id).await.ok(), }; let user_id_matches = user @@ -56,20 +56,15 @@ pub(crate) async fn search_users_route( let user_is_in_public_rooms = services .rooms .state_cache - .rooms_joined(&user_id) - .filter_map(Result::ok) + .rooms_joined(&user.user_id) .any(|room| { services .rooms .state_accessor - .room_state_get(&room, &StateEventType::RoomJoinRules, "") - .map_or(false, |event| { - event.map_or(false, |event| { - serde_json::from_str(event.content.get()) - .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) - }) - }) - }); + .room_state_get_content::<RoomJoinRulesEventContent>(room, &StateEventType::RoomJoinRules, "") + .map_ok_or(false, |content| content.join_rule == JoinRule::Public) + }) + .await; if user_is_in_public_rooms { user_visible = true; @@ -77,25 +72,22 @@ pub(crate) async fn search_users_route( let user_is_in_shared_rooms = services .rooms .user - .get_shared_rooms(vec![sender_user.clone(), user_id]) - .ok()? 
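Note that map_ok_or in the user_directory hunk comes from conduwuit's own TryFutureExtExt import, not from the futures crate. An equivalent sketch with stock futures combinators, collapsing a fallible async state lookup into a bool the way the join-rule check does:

use futures::TryFutureExt;

// Stand-in for room_state_get_content::<RoomJoinRulesEventContent>(..).
async fn join_rule(_room: &str) -> Result<String, ()> {
    Ok("public".to_owned())
}

async fn room_is_public(room: &str) -> bool {
    join_rule(room)
        .map_ok(|rule| rule == "public")
        .unwrap_or_else(|()| false) // any lookup error counts as "not public"
        .await
}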
- .next() - .is_some(); + .has_shared_rooms(sender_user, &user.user_id) + .await; if user_is_in_shared_rooms { user_visible = true; } } - if !user_visible { - return None; - } - - Some(user) + user_visible.then_some(user) }); - let results = users.by_ref().take(limit).collect(); - let limited = users.next().is_some(); + pin_mut!(users); + + let limited = users.by_ref().next().await.is_some(); + + let results = users.take(limit).collect().await; Ok(search_users::v3::Response { results, diff --git a/src/api/mod.rs b/src/api/mod.rs index 82b857db31472f665c216acc1c36693a4fe8849d..fc68af5b526aaf04ae60f0034e6de187de0c59f2 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,4 @@ -#![recursion_limit = "192"] +#![allow(clippy::toplevel_ref_arg)] pub mod client; pub mod router; @@ -8,7 +8,6 @@ extern crate conduit_service as service; pub(crate) use conduit::{debug_info, pdu::PduEvent, utils, Error, Result}; -pub(crate) use service::services; pub(crate) use self::router::{Ruma, RumaResponse, State}; diff --git a/src/api/router.rs b/src/api/router.rs index 4264e01df55731e1e7acf4673287a1003722d39f..1df4342fed4749d11bc363775dcdce61e82c9885 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -22,101 +22,103 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> { let config = &server.config; let mut router = router - .ruma_route(client::get_timezone_key_route) - .ruma_route(client::get_profile_key_route) - .ruma_route(client::set_profile_key_route) - .ruma_route(client::delete_profile_key_route) - .ruma_route(client::set_timezone_key_route) - .ruma_route(client::delete_timezone_key_route) - .ruma_route(client::appservice_ping) - .ruma_route(client::get_supported_versions_route) - .ruma_route(client::get_register_available_route) - .ruma_route(client::register_route) - .ruma_route(client::get_login_types_route) - .ruma_route(client::login_route) - .ruma_route(client::whoami_route) - .ruma_route(client::logout_route) - .ruma_route(client::logout_all_route) - .ruma_route(client::change_password_route) - .ruma_route(client::deactivate_route) - .ruma_route(client::third_party_route) - .ruma_route(client::request_3pid_management_token_via_email_route) - .ruma_route(client::request_3pid_management_token_via_msisdn_route) - .ruma_route(client::check_registration_token_validity) - .ruma_route(client::get_capabilities_route) - .ruma_route(client::get_pushrules_all_route) - .ruma_route(client::set_pushrule_route) - .ruma_route(client::get_pushrule_route) - .ruma_route(client::set_pushrule_enabled_route) - .ruma_route(client::get_pushrule_enabled_route) - .ruma_route(client::get_pushrule_actions_route) - .ruma_route(client::set_pushrule_actions_route) - .ruma_route(client::delete_pushrule_route) - .ruma_route(client::get_room_event_route) - .ruma_route(client::get_room_aliases_route) - .ruma_route(client::get_filter_route) - .ruma_route(client::create_filter_route) - .ruma_route(client::create_openid_token_route) - .ruma_route(client::set_global_account_data_route) - .ruma_route(client::set_room_account_data_route) - .ruma_route(client::get_global_account_data_route) - .ruma_route(client::get_room_account_data_route) - .ruma_route(client::set_displayname_route) - .ruma_route(client::get_displayname_route) - .ruma_route(client::set_avatar_url_route) - .ruma_route(client::get_avatar_url_route) - .ruma_route(client::get_profile_route) - .ruma_route(client::set_presence_route) - .ruma_route(client::get_presence_route) - .ruma_route(client::upload_keys_route) - 
.ruma_route(client::get_keys_route) - .ruma_route(client::claim_keys_route) - .ruma_route(client::create_backup_version_route) - .ruma_route(client::update_backup_version_route) - .ruma_route(client::delete_backup_version_route) - .ruma_route(client::get_latest_backup_info_route) - .ruma_route(client::get_backup_info_route) - .ruma_route(client::add_backup_keys_route) - .ruma_route(client::add_backup_keys_for_room_route) - .ruma_route(client::add_backup_keys_for_session_route) - .ruma_route(client::delete_backup_keys_for_room_route) - .ruma_route(client::delete_backup_keys_for_session_route) - .ruma_route(client::delete_backup_keys_route) - .ruma_route(client::get_backup_keys_for_room_route) - .ruma_route(client::get_backup_keys_for_session_route) - .ruma_route(client::get_backup_keys_route) - .ruma_route(client::set_read_marker_route) - .ruma_route(client::create_receipt_route) - .ruma_route(client::create_typing_event_route) - .ruma_route(client::create_room_route) - .ruma_route(client::redact_event_route) - .ruma_route(client::report_event_route) - .ruma_route(client::create_alias_route) - .ruma_route(client::delete_alias_route) - .ruma_route(client::get_alias_route) - .ruma_route(client::join_room_by_id_route) - .ruma_route(client::join_room_by_id_or_alias_route) - .ruma_route(client::joined_members_route) - .ruma_route(client::leave_room_route) - .ruma_route(client::forget_room_route) - .ruma_route(client::joined_rooms_route) - .ruma_route(client::kick_user_route) - .ruma_route(client::ban_user_route) - .ruma_route(client::unban_user_route) - .ruma_route(client::invite_user_route) - .ruma_route(client::set_room_visibility_route) - .ruma_route(client::get_room_visibility_route) - .ruma_route(client::get_public_rooms_route) - .ruma_route(client::get_public_rooms_filtered_route) - .ruma_route(client::search_users_route) - .ruma_route(client::get_member_events_route) - .ruma_route(client::get_protocols_route) + .ruma_route(&client::get_timezone_key_route) + .ruma_route(&client::get_profile_key_route) + .ruma_route(&client::set_profile_key_route) + .ruma_route(&client::delete_profile_key_route) + .ruma_route(&client::set_timezone_key_route) + .ruma_route(&client::delete_timezone_key_route) + .ruma_route(&client::appservice_ping) + .ruma_route(&client::get_supported_versions_route) + .ruma_route(&client::get_register_available_route) + .ruma_route(&client::register_route) + .ruma_route(&client::get_login_types_route) + .ruma_route(&client::login_route) + .ruma_route(&client::whoami_route) + .ruma_route(&client::logout_route) + .ruma_route(&client::logout_all_route) + .ruma_route(&client::change_password_route) + .ruma_route(&client::deactivate_route) + .ruma_route(&client::third_party_route) + .ruma_route(&client::request_3pid_management_token_via_email_route) + .ruma_route(&client::request_3pid_management_token_via_msisdn_route) + .ruma_route(&client::check_registration_token_validity) + .ruma_route(&client::get_capabilities_route) + .ruma_route(&client::get_pushrules_all_route) + .ruma_route(&client::get_pushrules_global_route) + .ruma_route(&client::set_pushrule_route) + .ruma_route(&client::get_pushrule_route) + .ruma_route(&client::set_pushrule_enabled_route) + .ruma_route(&client::get_pushrule_enabled_route) + .ruma_route(&client::get_pushrule_actions_route) + .ruma_route(&client::set_pushrule_actions_route) + .ruma_route(&client::delete_pushrule_route) + .ruma_route(&client::get_room_event_route) + .ruma_route(&client::get_room_aliases_route) + .ruma_route(&client::get_filter_route) 
+ .ruma_route(&client::create_filter_route) + .ruma_route(&client::create_openid_token_route) + .ruma_route(&client::set_global_account_data_route) + .ruma_route(&client::set_room_account_data_route) + .ruma_route(&client::get_global_account_data_route) + .ruma_route(&client::get_room_account_data_route) + .ruma_route(&client::set_displayname_route) + .ruma_route(&client::get_displayname_route) + .ruma_route(&client::set_avatar_url_route) + .ruma_route(&client::get_avatar_url_route) + .ruma_route(&client::get_profile_route) + .ruma_route(&client::set_presence_route) + .ruma_route(&client::get_presence_route) + .ruma_route(&client::upload_keys_route) + .ruma_route(&client::get_keys_route) + .ruma_route(&client::claim_keys_route) + .ruma_route(&client::create_backup_version_route) + .ruma_route(&client::update_backup_version_route) + .ruma_route(&client::delete_backup_version_route) + .ruma_route(&client::get_latest_backup_info_route) + .ruma_route(&client::get_backup_info_route) + .ruma_route(&client::add_backup_keys_route) + .ruma_route(&client::add_backup_keys_for_room_route) + .ruma_route(&client::add_backup_keys_for_session_route) + .ruma_route(&client::delete_backup_keys_for_room_route) + .ruma_route(&client::delete_backup_keys_for_session_route) + .ruma_route(&client::delete_backup_keys_route) + .ruma_route(&client::get_backup_keys_for_room_route) + .ruma_route(&client::get_backup_keys_for_session_route) + .ruma_route(&client::get_backup_keys_route) + .ruma_route(&client::set_read_marker_route) + .ruma_route(&client::create_receipt_route) + .ruma_route(&client::create_typing_event_route) + .ruma_route(&client::create_room_route) + .ruma_route(&client::redact_event_route) + .ruma_route(&client::report_event_route) + .ruma_route(&client::report_room_route) + .ruma_route(&client::create_alias_route) + .ruma_route(&client::delete_alias_route) + .ruma_route(&client::get_alias_route) + .ruma_route(&client::join_room_by_id_route) + .ruma_route(&client::join_room_by_id_or_alias_route) + .ruma_route(&client::joined_members_route) + .ruma_route(&client::leave_room_route) + .ruma_route(&client::forget_room_route) + .ruma_route(&client::joined_rooms_route) + .ruma_route(&client::kick_user_route) + .ruma_route(&client::ban_user_route) + .ruma_route(&client::unban_user_route) + .ruma_route(&client::invite_user_route) + .ruma_route(&client::set_room_visibility_route) + .ruma_route(&client::get_room_visibility_route) + .ruma_route(&client::get_public_rooms_route) + .ruma_route(&client::get_public_rooms_filtered_route) + .ruma_route(&client::search_users_route) + .ruma_route(&client::get_member_events_route) + .ruma_route(&client::get_protocols_route) .route("/_matrix/client/unstable/thirdparty/protocols", get(client::get_protocols_route_unstable)) - .ruma_route(client::send_message_event_route) - .ruma_route(client::send_state_event_for_key_route) - .ruma_route(client::get_state_events_route) - .ruma_route(client::get_state_events_for_key_route) + .ruma_route(&client::send_message_event_route) + .ruma_route(&client::send_state_event_for_key_route) + .ruma_route(&client::get_state_events_route) + .ruma_route(&client::get_state_events_for_key_route) // Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes // share one Ruma request / response type pair with {get,send}_state_event_for_key_route .route( @@ -140,46 +142,46 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> { get(client::get_state_events_for_empty_key_route) 
.put(client::send_state_event_for_empty_key_route), ) - .ruma_route(client::sync_events_route) - .ruma_route(client::sync_events_v4_route) - .ruma_route(client::get_context_route) - .ruma_route(client::get_message_events_route) - .ruma_route(client::search_events_route) - .ruma_route(client::turn_server_route) - .ruma_route(client::send_event_to_device_route) - .ruma_route(client::create_content_route) - .ruma_route(client::get_content_thumbnail_route) - .ruma_route(client::get_content_route) - .ruma_route(client::get_content_as_filename_route) - .ruma_route(client::get_media_preview_route) - .ruma_route(client::get_media_config_route) - .ruma_route(client::get_devices_route) - .ruma_route(client::get_device_route) - .ruma_route(client::update_device_route) - .ruma_route(client::delete_device_route) - .ruma_route(client::delete_devices_route) - .ruma_route(client::get_tags_route) - .ruma_route(client::update_tag_route) - .ruma_route(client::delete_tag_route) - .ruma_route(client::upload_signing_keys_route) - .ruma_route(client::upload_signatures_route) - .ruma_route(client::get_key_changes_route) - .ruma_route(client::get_pushers_route) - .ruma_route(client::set_pushers_route) - .ruma_route(client::upgrade_room_route) - .ruma_route(client::get_threads_route) - .ruma_route(client::get_relating_events_with_rel_type_and_event_type_route) - .ruma_route(client::get_relating_events_with_rel_type_route) - .ruma_route(client::get_relating_events_route) - .ruma_route(client::get_hierarchy_route) - .ruma_route(client::get_mutual_rooms_route) - .ruma_route(client::get_room_summary) + .ruma_route(&client::sync_events_route) + .ruma_route(&client::sync_events_v4_route) + .ruma_route(&client::get_context_route) + .ruma_route(&client::get_message_events_route) + .ruma_route(&client::search_events_route) + .ruma_route(&client::turn_server_route) + .ruma_route(&client::send_event_to_device_route) + .ruma_route(&client::create_content_route) + .ruma_route(&client::get_content_thumbnail_route) + .ruma_route(&client::get_content_route) + .ruma_route(&client::get_content_as_filename_route) + .ruma_route(&client::get_media_preview_route) + .ruma_route(&client::get_media_config_route) + .ruma_route(&client::get_devices_route) + .ruma_route(&client::get_device_route) + .ruma_route(&client::update_device_route) + .ruma_route(&client::delete_device_route) + .ruma_route(&client::delete_devices_route) + .ruma_route(&client::get_tags_route) + .ruma_route(&client::update_tag_route) + .ruma_route(&client::delete_tag_route) + .ruma_route(&client::upload_signing_keys_route) + .ruma_route(&client::upload_signatures_route) + .ruma_route(&client::get_key_changes_route) + .ruma_route(&client::get_pushers_route) + .ruma_route(&client::set_pushers_route) + .ruma_route(&client::upgrade_room_route) + .ruma_route(&client::get_threads_route) + .ruma_route(&client::get_relating_events_with_rel_type_and_event_type_route) + .ruma_route(&client::get_relating_events_with_rel_type_route) + .ruma_route(&client::get_relating_events_route) + .ruma_route(&client::get_hierarchy_route) + .ruma_route(&client::get_mutual_rooms_route) + .ruma_route(&client::get_room_summary) .route( "/_matrix/client/unstable/im.nheko.summary/rooms/:room_id_or_alias/summary", get(client::get_room_summary_legacy) ) - .ruma_route(client::well_known_support) - .ruma_route(client::well_known_client) + .ruma_route(&client::well_known_support) + .ruma_route(&client::well_known_client) .route("/_conduwuit/server_version", get(client::conduwuit_server_version)) 
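These registrations now pass &client::... instead of moving the handler, and RumaHandler (later in this diff) takes &'static self. A sketch of why that borrow is free, under the assumption (borne out by the trait bounds below) that handlers are plain fn items: a fn item is a zero-sized constant, so a reference to it is statically promoted and the router can hold it forever without cloning.

fn register<H: Fn() + Send + Sync>(handler: &'static H) -> &'static H {
    handler // stored as-is; no Clone bound required
}

fn demo() {
    fn whoami_route() {}
    // &whoami_route is promoted to a 'static reference to a zero-sized value.
    let h = register(&whoami_route);
    h();
}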
.route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync)) .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) @@ -187,35 +189,35 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> { if config.allow_federation { router = router - .ruma_route(server::get_server_version_route) + .ruma_route(&server::get_server_version_route) .route("/_matrix/key/v2/server", get(server::get_server_keys_route)) .route("/_matrix/key/v2/server/:key_id", get(server::get_server_keys_deprecated_route)) - .ruma_route(server::get_public_rooms_route) - .ruma_route(server::get_public_rooms_filtered_route) - .ruma_route(server::send_transaction_message_route) - .ruma_route(server::get_event_route) - .ruma_route(server::get_backfill_route) - .ruma_route(server::get_missing_events_route) - .ruma_route(server::get_event_authorization_route) - .ruma_route(server::get_room_state_route) - .ruma_route(server::get_room_state_ids_route) - .ruma_route(server::create_leave_event_template_route) - .ruma_route(server::create_leave_event_v1_route) - .ruma_route(server::create_leave_event_v2_route) - .ruma_route(server::create_join_event_template_route) - .ruma_route(server::create_join_event_v1_route) - .ruma_route(server::create_join_event_v2_route) - .ruma_route(server::create_invite_route) - .ruma_route(server::get_devices_route) - .ruma_route(server::get_room_information_route) - .ruma_route(server::get_profile_information_route) - .ruma_route(server::get_keys_route) - .ruma_route(server::claim_keys_route) - .ruma_route(server::get_openid_userinfo_route) - .ruma_route(server::get_hierarchy_route) - .ruma_route(server::well_known_server) - .ruma_route(server::get_content_route) - .ruma_route(server::get_content_thumbnail_route) + .ruma_route(&server::get_public_rooms_route) + .ruma_route(&server::get_public_rooms_filtered_route) + .ruma_route(&server::send_transaction_message_route) + .ruma_route(&server::get_event_route) + .ruma_route(&server::get_backfill_route) + .ruma_route(&server::get_missing_events_route) + .ruma_route(&server::get_event_authorization_route) + .ruma_route(&server::get_room_state_route) + .ruma_route(&server::get_room_state_ids_route) + .ruma_route(&server::create_leave_event_template_route) + .ruma_route(&server::create_leave_event_v1_route) + .ruma_route(&server::create_leave_event_v2_route) + .ruma_route(&server::create_join_event_template_route) + .ruma_route(&server::create_join_event_v1_route) + .ruma_route(&server::create_join_event_v2_route) + .ruma_route(&server::create_invite_route) + .ruma_route(&server::get_devices_route) + .ruma_route(&server::get_room_information_route) + .ruma_route(&server::get_profile_information_route) + .ruma_route(&server::get_keys_route) + .ruma_route(&server::claim_keys_route) + .ruma_route(&server::get_openid_userinfo_route) + .ruma_route(&server::get_hierarchy_route) + .ruma_route(&server::well_known_server) + .ruma_route(&server::get_content_route) + .ruma_route(&server::get_content_thumbnail_route) .route("/_conduwuit/local_user_count", get(client::conduwuit_local_user_count)); } else { router = router @@ -227,11 +229,11 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> { if config.allow_legacy_media { router = router - .ruma_route(client::get_media_config_legacy_route) - .ruma_route(client::get_media_preview_legacy_route) - .ruma_route(client::get_content_legacy_route) - .ruma_route(client::get_content_as_filename_legacy_route) - 
.ruma_route(client::get_content_thumbnail_legacy_route) + .ruma_route(&client::get_media_config_legacy_route) + .ruma_route(&client::get_media_preview_legacy_route) + .ruma_route(&client::get_content_legacy_route) + .ruma_route(&client::get_content_as_filename_legacy_route) + .ruma_route(&client::get_content_thumbnail_legacy_route) .route("/_matrix/media/v1/config", get(client::get_media_config_legacy_legacy_route)) .route("/_matrix/media/v1/upload", post(client::create_content_legacy_route)) .route( diff --git a/src/api/router/args.rs b/src/api/router/args.rs index a3d09dff56f48a9a797fcc5dd1f0c85a90dd9d44..0b69395695eb55f9faaae2f93ac8265d0fc17764 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -1,9 +1,11 @@ use std::{mem, ops::Deref}; use axum::{async_trait, body::Body, extract::FromRequest}; -use bytes::{BufMut, BytesMut}; -use conduit::{debug, err, trace, utils::string::EMPTY, Error, Result}; -use ruma::{api::IncomingRequest, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId}; +use bytes::{BufMut, Bytes, BytesMut}; +use conduit::{debug, err, utils::string::EMPTY, Error, Result}; +use ruma::{ + api::IncomingRequest, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName, OwnedUserId, ServerName, UserId, +}; use service::Services; use super::{auth, auth::Auth, request, request::Request}; @@ -35,19 +37,57 @@ pub(crate) struct Args<T> { pub(crate) json_body: Option<CanonicalJsonValue>, } +impl<T> Args<T> +where + T: IncomingRequest + Send + Sync + 'static, +{ + #[inline] + pub(crate) fn sender(&self) -> (&UserId, &DeviceId) { (self.sender_user(), self.sender_device()) } + + #[inline] + pub(crate) fn sender_user(&self) -> &UserId { + self.sender_user + .as_deref() + .expect("user must be authenticated for this handler") + } + + #[inline] + pub(crate) fn sender_device(&self) -> &DeviceId { + self.sender_device + .as_deref() + .expect("user must be authenticated and device identified") + } + + #[inline] + pub(crate) fn origin(&self) -> &ServerName { + self.origin + .as_deref() + .expect("server must be authenticated for this handler") + } +} + +impl<T> Deref for Args<T> +where + T: IncomingRequest + Send + Sync + 'static, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { &self.body } +} + #[async_trait] impl<T> FromRequest<State, Body> for Args<T> where - T: IncomingRequest, + T: IncomingRequest + Send + Sync + 'static, { type Rejection = Error; async fn from_request(request: hyper::Request<Body>, services: &State) -> Result<Self, Self::Rejection> { let mut request = request::from(services, request).await?; let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok(); - let auth = auth::auth(services, &mut request, &json_body, &T::METADATA).await?; + let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?; Ok(Self { - body: make_body::<T>(services, &mut request, &mut json_body, &auth)?, + body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?, origin: auth.origin, sender_user: auth.sender_user, sender_device: auth.sender_device, @@ -57,61 +97,65 @@ async fn from_request(request: hyper::Request<Body>, services: &State) -> Result } } -impl<T> Deref for Args<T> { - type Target = T; - - fn deref(&self) -> &Self::Target { &self.body } -} - fn make_body<T>( - services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth, + services: &Services, request: &mut Request, json_body: Option<&mut CanonicalJsonValue>, auth: &Auth, ) -> 
Result<T> where T: IncomingRequest, { - let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body { - let user_id = auth.sender_user.clone().unwrap_or_else(|| { - let server_name = services.globals.server_name(); - UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id") - }); - - let uiaa_request = json_body - .get("auth") - .and_then(|auth| auth.as_object()) - .and_then(|auth| auth.get("session")) - .and_then(|session| session.as_str()) - .and_then(|session| { - services.uiaa.get_uiaa_request( - &user_id, - &auth.sender_device.clone().unwrap_or_else(|| EMPTY.into()), - session, - ) - }); - - if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { - for (key, value) in initial_request { - json_body.entry(key).or_insert(value); - } - } - - let mut buf = BytesMut::new().writer(); - serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail"); - buf.into_inner().freeze() - } else { - mem::take(&mut request.body) - }; + let body = take_body(services, request, json_body, auth); + let http_request = into_http_request(request, body); + T::try_from_http_request(http_request, &request.path).map_err(|e| err!(Request(BadJson(debug_warn!("{e}"))))) +} +fn into_http_request(request: &Request, body: Bytes) -> hyper::Request<Bytes> { let mut http_request = hyper::Request::builder() .uri(request.parts.uri.clone()) .method(request.parts.method.clone()); + *http_request.headers_mut().expect("mutable http headers") = request.parts.headers.clone(); + let http_request = http_request.body(body).expect("http request body"); let headers = http_request.headers(); let method = http_request.method(); let uri = http_request.uri(); debug!("{method:?} {uri:?} {headers:?}"); - trace!("{method:?} {uri:?} {json_body:?}"); - T::try_from_http_request(http_request, &request.path).map_err(|e| err!(Request(BadJson(debug_warn!("{e}"))))) + http_request +} + +#[allow(clippy::needless_pass_by_value)] +fn take_body( + services: &Services, request: &mut Request, json_body: Option<&mut CanonicalJsonValue>, auth: &Auth, +) -> Bytes { + let Some(CanonicalJsonValue::Object(json_body)) = json_body else { + return mem::take(&mut request.body); + }; + + let user_id = auth.sender_user.clone().unwrap_or_else(|| { + let server_name = services.globals.server_name(); + UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id") + }); + + let uiaa_request = json_body + .get("auth") + .and_then(CanonicalJsonValue::as_object) + .and_then(|auth| auth.get("session")) + .and_then(CanonicalJsonValue::as_str) + .and_then(|session| { + services + .uiaa + .get_uiaa_request(&user_id, auth.sender_device.as_deref(), session) + }); + + if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { + for (key, value) in initial_request { + json_body.entry(key).or_insert(value); + } + } + + let mut buf = BytesMut::new().writer(); + serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail"); + buf.into_inner().freeze() } diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 670f72ba8c9f25d87e297f171200047852c8a780..68abf5e2c3cac3adf8333e1e78f66d84b8493999 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -1,19 +1,28 @@ -use std::collections::BTreeMap; - use axum::RequestPartsExt; use axum_extra::{ headers::{authorization::Bearer, Authorization}, typed_header::TypedHeaderRejectionReason, TypedHeader, }; -use conduit::{debug_info, warn, Err, Error, Result}; -use http::uri::PathAndQuery; +use 
conduit::{debug_error, err, warn, Err, Error, Result}; use ruma::{ - api::{client::error::ErrorKind, AuthScheme, Metadata}, + api::{ + client::{ + directory::get_public_rooms, + error::ErrorKind, + profile::{get_avatar_url, get_display_name, get_profile, get_profile_key, get_timezone_key}, + voip::get_turn_server_info, + }, + federation::openid::get_openid_userinfo, + AuthScheme, IncomingRequest, Metadata, + }, server_util::authorization::XMatrix, - CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, + CanonicalJsonObject, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, +}; +use service::{ + server_keys::{PubKeyMap, PubKeys}, + Services, }; -use service::Services; use super::request::Request; use crate::service::appservice::RegistrationInfo; @@ -33,7 +42,7 @@ pub(super) struct Auth { } pub(super) async fn auth( - services: &Services, request: &mut Request, json_body: &Option<CanonicalJsonValue>, metadata: &Metadata, + services: &Services, request: &mut Request, json_body: Option<&CanonicalJsonValue>, metadata: &Metadata, ) -> Result<Auth> { let bearer: Option<TypedHeader<Authorization<Bearer>>> = request.parts.extract().await?; let token = match &bearer { @@ -44,8 +53,8 @@ pub(super) async fn auth( let token = if let Some(token) = token { if let Some(reg_info) = services.appservice.find_from_token(token).await { Token::Appservice(Box::new(reg_info)) - } else if let Some((user_id, device_id)) = services.users.find_from_token(token)? { - Token::User((user_id, OwnedDeviceId::from(device_id))) + } else if let Ok((user_id, device_id)) = services.users.find_from_token(token).await { + Token::User((user_id, device_id)) } else { Token::Invalid } @@ -54,9 +63,8 @@ pub(super) async fn auth( }; if metadata.authentication == AuthScheme::None { - match request.parts.uri.path() { - // TODO: can we check this better? - "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { + match metadata { + &get_public_rooms::v3::Request::METADATA => { if !services .globals .config @@ -73,32 +81,29 @@ pub(super) async fn auth( } } }, + &get_profile::v3::Request::METADATA + | &get_profile_key::unstable::Request::METADATA + | &get_display_name::v3::Request::METADATA + | &get_avatar_url::v3::Request::METADATA + | &get_timezone_key::unstable::Request::METADATA => { + if services.globals.config.require_auth_for_profile_requests { + match token { + Token::Appservice(_) | Token::User(_) => { + // we should have validated the token above + // already + }, + Token::None | Token::Invalid => { + return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing or invalid access token.")); + }, + } + } + }, _ => {}, }; } match (metadata.authentication, token) { - (_, Token::Invalid) => { - // OpenID endpoint uses a query param with the same name, drop this once query - // params for user auth are removed from the spec. This is required to make - // integration manager work. 
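A miniature of the matching strategy above, with hypothetical types: rather than comparing request.parts.uri.path() against every historical spelling of an endpoint, the check compares the endpoint's METADATA constant, which is stable across path renames and spec versions.

#[derive(PartialEq, Eq)]
struct Metadata {
    name: &'static str,
}

struct GetPublicRooms;

impl GetPublicRooms {
    const METADATA: Metadata = Metadata { name: "get_public_rooms" };
}

fn is_public_rooms_endpoint(metadata: &Metadata) -> bool {
    // One comparison covers /r0/, /v3/ and any future path of the endpoint.
    *metadata == GetPublicRooms::METADATA
}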
- if request.query.access_token.is_some() && request.parts.uri.path().contains("/openid/") { - Ok(Auth { - origin: None, - sender_user: None, - sender_device: None, - appservice_info: None, - }) - } else { - Err(Error::BadRequest( - ErrorKind::UnknownToken { - soft_logout: false, - }, - "Unknown access token.", - )) - } - }, - (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info)?), + (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info).await?), (AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => { Ok(Auth { origin: None, @@ -107,9 +112,8 @@ pub(super) async fn auth( appservice_info: Some(*info), }) }, - (AuthScheme::AccessToken, Token::None) => match request.parts.uri.path() { - // TODO: can we check this better? - "/_matrix/client/v3/voip/turnServer" | "/_matrix/client/r0/voip/turnServer" => { + (AuthScheme::AccessToken, Token::None) => match metadata { + &get_turn_server_info::v3::Request::METADATA => { if services.globals.config.turn_allow_guests { Ok(Auth { origin: None, @@ -147,31 +151,54 @@ pub(super) async fn auth( ErrorKind::Unauthorized, "Only appservice access tokens should be used on this endpoint.", )), + (AuthScheme::None, Token::Invalid) => { + // OpenID federation endpoint uses a query param with the same name, drop this + // once query params for user auth are removed from the spec. This is + // required to make integration manager work. + if request.query.access_token.is_some() && metadata == &get_openid_userinfo::v1::Request::METADATA { + Ok(Auth { + origin: None, + sender_user: None, + sender_device: None, + appservice_info: None, + }) + } else { + Err(Error::BadRequest( + ErrorKind::UnknownToken { + soft_logout: false, + }, + "Unknown access token.", + )) + } + }, + (_, Token::Invalid) => Err(Error::BadRequest( + ErrorKind::UnknownToken { + soft_logout: false, + }, + "Unknown access token.", + )), } } -fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> { - let user_id = request +async fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> { + let user_id_default = + || UserId::parse_with_server_name(info.registration.sender_localpart.as_str(), services.globals.server_name()); + + let Ok(user_id) = request .query .user_id .clone() - .map_or_else( - || { - UserId::parse_with_server_name( - info.registration.sender_localpart.as_str(), - services.globals.server_name(), - ) - }, - UserId::parse, - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; + .map_or_else(user_id_default, UserId::parse) + else { + return Err!(Request(InvalidUsername("Username is invalid."))); + }; if !info.is_user_match(&user_id) { - return Err(Error::BadRequest(ErrorKind::Exclusive, "User is not in namespace.")); + return Err!(Request(Exclusive("User is not in namespace."))); } - if !services.users.exists(&user_id)? 
{ - return Err(Error::BadRequest(ErrorKind::forbidden(), "User does not exist.")); + if !services.users.exists(&user_id).await { + return Err!(Request(Forbidden("User does not exist."))); } Ok(Auth { @@ -182,118 +209,115 @@ fn auth_appservice(services: &Services, request: &Request, info: Box<Registratio }) } -async fn auth_server( - services: &Services, request: &mut Request, json_body: &Option<CanonicalJsonValue>, -) -> Result<Auth> { - if !services.server.config.allow_federation { - return Err!(Config("allow_federation", "Federation is disabled.")); - } +async fn auth_server(services: &Services, request: &mut Request, body: Option<&CanonicalJsonValue>) -> Result<Auth> { + type Member = (String, CanonicalJsonValue); + type Object = CanonicalJsonObject; + type Value = CanonicalJsonValue; - let TypedHeader(Authorization(x_matrix)) = request + let x_matrix = parse_x_matrix(request).await?; + auth_server_checks(services, &x_matrix)?; + + let destination = services.globals.server_name(); + let origin = &x_matrix.origin; + let signature_uri = request .parts - .extract::<TypedHeader<Authorization<XMatrix>>>() - .await - .map_err(|e| { - warn!("Missing or invalid Authorization header: {e}"); + .uri + .path_and_query() + .expect("all requests have a path") + .to_string(); - let msg = match e.reason() { - TypedHeaderRejectionReason::Missing => "Missing Authorization header.", - TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.", - _ => "Unknown header-related error", - }; + let signature: [Member; 1] = [(x_matrix.key.as_str().into(), Value::String(x_matrix.sig.to_string()))]; - Error::BadRequest(ErrorKind::forbidden(), msg) - })?; + let signatures: [Member; 1] = [(origin.as_str().into(), Value::Object(signature.into()))]; - let origin = &x_matrix.origin; + let authorization: Object = if let Some(body) = body.cloned() { + let authorization: [Member; 6] = [ + ("content".into(), body), + ("destination".into(), Value::String(destination.into())), + ("method".into(), Value::String(request.parts.method.as_str().into())), + ("origin".into(), Value::String(origin.as_str().into())), + ("signatures".into(), Value::Object(signatures.into())), + ("uri".into(), Value::String(signature_uri)), + ]; - if services - .server - .config - .forbidden_remote_server_names - .contains(origin) - { - debug_info!("Refusing to accept inbound federation request to {origin}"); - return Err!(Request(Forbidden("Federation with this homeserver is not allowed."))); - } + authorization.into() + } else { + let authorization: [Member; 5] = [ + ("destination".into(), Value::String(destination.into())), + ("method".into(), Value::String(request.parts.method.as_str().into())), + ("origin".into(), Value::String(origin.as_str().into())), + ("signatures".into(), Value::Object(signatures.into())), + ("uri".into(), Value::String(signature_uri)), + ]; - let signatures = - BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig.to_string()))]); - let signatures = BTreeMap::from_iter([( - origin.as_str().to_owned(), - CanonicalJsonValue::Object( - signatures - .into_iter() - .map(|(k, v)| (k.to_string(), v)) - .collect(), - ), - )]); + authorization.into() + }; + + let key = services + .server_keys + .get_verify_key(origin, &x_matrix.key) + .await + .map_err(|e| err!(Request(Forbidden(warn!("Failed to fetch signing keys: {e}")))))?; - let server_destination = services.globals.server_name().as_str().to_owned(); - if let Some(destination) = x_matrix.destination.as_ref() { - if destination != 
&server_destination { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Invalid authorization.")); + let keys: PubKeys = [(x_matrix.key.to_string(), key.key)].into(); + let keys: PubKeyMap = [(origin.as_str().into(), keys)].into(); + if let Err(e) = ruma::signatures::verify_json(&keys, authorization) { + debug_error!("Failed to verify federation request from {origin}: {e}"); + if request.parts.uri.to_string().contains('@') { + warn!( + "Request uri contained '@' character. Make sure your reverse proxy gives conduwuit the raw uri \ + (apache: use nocanon)" + ); } - } - #[allow(clippy::or_fun_call)] - let signature_uri = CanonicalJsonValue::String( - request - .parts - .uri - .path_and_query() - .unwrap_or(&PathAndQuery::from_static("/")) - .to_string(), - ); + return Err!(Request(Forbidden("Failed to verify X-Matrix signatures."))); + } - let mut request_map = BTreeMap::from_iter([ - ( - "method".to_owned(), - CanonicalJsonValue::String(request.parts.method.to_string()), - ), - ("uri".to_owned(), signature_uri), - ("origin".to_owned(), CanonicalJsonValue::String(origin.as_str().to_owned())), - ("destination".to_owned(), CanonicalJsonValue::String(server_destination)), - ("signatures".to_owned(), CanonicalJsonValue::Object(signatures)), - ]); + Ok(Auth { + origin: origin.to_owned().into(), + sender_user: None, + sender_device: None, + appservice_info: None, + }) +} - if let Some(json_body) = json_body { - request_map.insert("content".to_owned(), json_body.clone()); - }; +fn auth_server_checks(services: &Services, x_matrix: &XMatrix) -> Result<()> { + if !services.server.config.allow_federation { + return Err!(Config("allow_federation", "Federation is disabled.")); + } - let keys_result = services - .server_keys - .fetch_signing_keys_for_server(origin, vec![x_matrix.key.to_string()]) - .await; + let destination = services.globals.server_name(); + if x_matrix.destination.as_deref() != Some(destination) { + return Err!(Request(Forbidden("Invalid destination."))); + } - let keys = keys_result.map_err(|e| { - warn!("Failed to fetch signing keys: {e}"); - Error::BadRequest(ErrorKind::forbidden(), "Failed to fetch signing keys.") - })?; + let origin = &x_matrix.origin; + if services + .server + .config + .forbidden_remote_server_names + .contains(origin) + { + return Err!(Request(Forbidden(debug_warn!("Federation requests from {origin} denied.")))); + } - let pub_key_map = BTreeMap::from_iter([(origin.as_str().to_owned(), keys)]); + Ok(()) +} - match ruma::signatures::verify_json(&pub_key_map, &request_map) { - Ok(()) => Ok(Auth { - origin: Some(origin.clone()), - sender_user: None, - sender_device: None, - appservice_info: None, - }), - Err(e) => { - warn!("Failed to verify json request from {origin}: {e}\n{request_map:?}"); +async fn parse_x_matrix(request: &mut Request) -> Result<XMatrix> { + let TypedHeader(Authorization(x_matrix)) = request + .parts + .extract::<TypedHeader<Authorization<XMatrix>>>() + .await + .map_err(|e| { + let msg = match e.reason() { + TypedHeaderRejectionReason::Missing => "Missing Authorization header.", + TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.", + _ => "Unknown header-related error", + }; - if request.parts.uri.to_string().contains('@') { - warn!( - "Request uri contained '@' character. 
Make sure your reverse proxy gives Conduit the raw uri \ - (apache: use nocanon)" - ); - } + err!(Request(Forbidden(warn!("{msg}: {e}")))) + })?; - Err(Error::BadRequest( - ErrorKind::forbidden(), - "Failed to verify X-Matrix signatures.", - )) - }, - } + Ok(x_matrix) } diff --git a/src/api/router/handler.rs b/src/api/router/handler.rs index d112ec58746ea56dc3c75d04a121cc3f7050253e..0022f06a9c1a7f7bba0b52f750ccc3fa76cdb49f 100644 --- a/src/api/router/handler.rs +++ b/src/api/router/handler.rs @@ -1,5 +1,3 @@ -use std::future::Future; - use axum::{ extract::FromRequestParts, response::IntoResponse, @@ -7,19 +5,25 @@ Router, }; use conduit::Result; +use futures::{Future, TryFutureExt}; use http::Method; use ruma::api::IncomingRequest; use super::{Ruma, RumaResponse, State}; +pub(in super::super) trait RumaHandler<T> { + fn add_route(&'static self, router: Router<State>, path: &str) -> Router<State>; + fn add_routes(&'static self, router: Router<State>) -> Router<State>; +} + pub(in super::super) trait RouterExt { - fn ruma_route<H, T>(self, handler: H) -> Self + fn ruma_route<H, T>(self, handler: &'static H) -> Self where H: RumaHandler<T>; } impl RouterExt for Router<State> { - fn ruma_route<H, T>(self, handler: H) -> Self + fn ruma_route<H, T>(self, handler: &'static H) -> Self where H: RumaHandler<T>, { @@ -27,34 +31,28 @@ fn ruma_route<H, T>(self, handler: H) -> Self } } -pub(in super::super) trait RumaHandler<T> { - fn add_routes(&self, router: Router<State>) -> Router<State>; - - fn add_route(&self, router: Router<State>, path: &str) -> Router<State>; -} - macro_rules! ruma_handler { ( $($tx:ident),* $(,)? ) => { #[allow(non_snake_case)] - impl<Req, Ret, Fut, Fun, $($tx,)*> RumaHandler<($($tx,)* Ruma<Req>,)> for Fun + impl<Err, Req, Fut, Fun, $($tx,)*> RumaHandler<($($tx,)* Ruma<Req>,)> for Fun where - Req: IncomingRequest + Send + 'static, - Ret: IntoResponse, - Fut: Future<Output = Result<Req::OutgoingResponse, Ret>> + Send, - Fun: FnOnce($($tx,)* Ruma<Req>,) -> Fut + Clone + Send + Sync + 'static, - $( $tx: FromRequestParts<State> + Send + 'static, )* + Fun: Fn($($tx,)* Ruma<Req>,) -> Fut + Send + Sync + 'static, + Fut: Future<Output = Result<Req::OutgoingResponse, Err>> + Send, + Req: IncomingRequest + Send + Sync + 'static, + Err: IntoResponse + Send, + <Req as IncomingRequest>::OutgoingResponse: Send, + $( $tx: FromRequestParts<State> + Send + Sync + 'static, )* { - fn add_routes(&self, router: Router<State>) -> Router<State> { + fn add_routes(&'static self, router: Router<State>) -> Router<State> { Req::METADATA .history .all_paths() .fold(router, |router, path| self.add_route(router, path)) } - fn add_route(&self, router: Router<State>, path: &str) -> Router<State> { - let handle = self.clone(); + fn add_route(&'static self, router: Router<State>, path: &str) -> Router<State> { + let action = |$($tx,)* req| self($($tx,)* req).map_ok(RumaResponse); let method = method_to_filter(&Req::METADATA.method); - let action = |$($tx,)* req| async { handle($($tx,)* req).await.map(RumaResponse) }; router.route(path, on(method, action)) } } diff --git a/src/api/router/response.rs b/src/api/router/response.rs index 2aaa79faa86d6238dbbcece5a580e85334fd79da..70bbb93644f19455cd51caeaae8ba50135f57d62 100644 --- a/src/api/router/response.rs +++ b/src/api/router/response.rs @@ -5,13 +5,18 @@ use http_body_util::Full; use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; -pub(crate) struct RumaResponse<T>(pub(crate) T); +pub(crate) struct RumaResponse<T>(pub(crate) T) +where + T: 
OutgoingResponse; impl From<Error> for RumaResponse<UiaaResponse> { fn from(t: Error) -> Self { Self(t.into()) } } -impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> { +impl<T> IntoResponse for RumaResponse<T> +where + T: OutgoingResponse, +{ fn into_response(self) -> Response { self.0 .try_into_http_response::<BytesMut>() diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 1b665c19d3bfa54f4f982597b032bc898597afaa..b0bd48e807c9f266f4a2cd0c5f71832236a8f7f2 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -1,10 +1,14 @@ +use std::cmp; + use axum::extract::State; -use conduit::{Error, Result}; -use ruma::{ - api::{client::error::ErrorKind, federation::backfill::get_backfill}, - uint, user_id, MilliSecondsSinceUnixEpoch, +use conduit::{ + utils::{IterStream, ReadyExt}, + PduCount, Result, }; +use futures::{FutureExt, StreamExt}; +use ruma::{api::federation::backfill::get_backfill, uint, MilliSecondsSinceUnixEpoch}; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/backfill/<room_id>` @@ -12,34 +16,16 @@ /// Retrieves events from before the sender joined the room, if the room's /// history visibility allows. pub(crate) async fn get_backfill_route( - State(services): State<crate::State>, body: Ruma<get_backfill::v1::Request>, + State(services): State<crate::State>, ref body: Ruma<get_backfill::v1::Request>, ) -> Result<get_backfill::v1::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - - services - .rooms - .event_handler - .acl_check(origin, &body.room_id)?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } - - let until = body - .v - .iter() - .map(|event_id| services.rooms.timeline.get_pdu_count(event_id)) - .filter_map(|r| r.ok().flatten()) - .max() - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event not found."))?; + .check() + .await?; let limit = body .limit @@ -47,31 +33,49 @@ pub(crate) async fn get_backfill_route( .try_into() .expect("UInt could not be converted to usize"); - let all_events = services - .rooms - .timeline - .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)? - .take(limit); - - let events = all_events - .filter_map(Result::ok) - .filter(|(_, e)| { - matches!( - services - .rooms - .state_accessor - .server_can_see_event(origin, &e.room_id, &e.event_id,), - Ok(true), - ) + let from = body + .v + .iter() + .stream() + .filter_map(|event_id| { + services + .rooms + .timeline + .get_pdu_count(event_id) + .map(Result::ok) }) - .map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id)) - .filter_map(|r| r.ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(); + .ready_fold(PduCount::min(), cmp::max) + .await; Ok(get_backfill::v1::Response { - origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdus: events, + + origin: services.globals.server_name().to_owned(), + + pdus: services + .rooms + .timeline + .pdus_rev(None, &body.room_id, Some(from.saturating_add(1))) + .await? 
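ready_fold above is a conduit-internal stream helper; the same reduction — pick the newest of the counts resolved from the client-supplied v events — can be sketched with stock futures::StreamExt::fold, with u64 standing in for PduCount:

use std::cmp;

use futures::{stream, StreamExt};

async fn newest(counts: Vec<u64>) -> u64 {
    stream::iter(counts)
        .fold(u64::MIN, |acc, count| async move { cmp::max(acc, count) })
        .await
}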
+ .take(limit) + .filter_map(|(_, pdu)| async move { + services + .rooms + .state_accessor + .server_can_see_event(body.origin(), &pdu.room_id, &pdu.event_id) + .await + .then_some(pdu) + }) + .filter_map(|pdu| async move { + services + .rooms + .timeline + .get_pdu_json(&pdu.event_id) + .await + .ok() + }) + .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) + .collect() + .await, }) } diff --git a/src/api/server/event.rs b/src/api/server/event.rs index e11a01a20ab600ed9f1595cd9f544765a94381f7..29d5d87037606da2b762cf6dacb3087b05f946d3 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -1,10 +1,8 @@ use axum::extract::State; -use conduit::{Error, Result}; -use ruma::{ - api::{client::error::ErrorKind, federation::event::get_event}, - MilliSecondsSinceUnixEpoch, RoomId, -}; +use conduit::{err, Result}; +use ruma::{api::federation::event::get_event, MilliSecondsSinceUnixEpoch, RoomId}; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/event/{eventId}` @@ -16,39 +14,35 @@ pub(crate) async fn get_event_route( State(services): State<crate::State>, body: Ruma<get_event::v1::Request>, ) -> Result<get_event::v1::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - let event = services .rooms .timeline - .get_pdu_json(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; + .get_pdu_json(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Event not found."))))?; - let room_id_str = event + let room_id: &RoomId = event .get("room_id") .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database."))?; - - let room_id = - <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; - - if !services.rooms.state_accessor.is_world_readable(room_id)? - && !services.rooms.state_cache.server_in_room(origin, room_id)? - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); - } - - if !services - .rooms - .state_accessor - .server_can_see_event(origin, room_id, &body.event_id)? - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not allowed to see event.")); + .ok_or_else(|| err!(Database("Invalid event in database.")))? 
+ .try_into() + .map_err(|_| err!(Database("Invalid room_id in event in database.")))?; + + AccessCheck { + services: &services, + origin: body.origin(), + room_id, + event_id: Some(&body.event_id), } + .check() + .await?; Ok(get_event::v1::Response { origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdu: services.sending.convert_to_outgoing_federation_event(event), + pdu: services + .sending + .convert_to_outgoing_federation_event(event) + .await, }) } diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 4b0f6bc0010d7b0d7e76b2a10be5b2dd0340fc4c..faeb2b99704643b258b7c361c118dce26360ca00 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -1,12 +1,14 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; use conduit::{Error, Result}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, RoomId, }; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/event_auth/{roomId}/{eventId}` @@ -17,30 +19,21 @@ pub(crate) async fn get_event_authorization_route( State(services): State<crate::State>, body: Ruma<get_event_authorization::v1::Request>, ) -> Result<get_event_authorization::v1::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - - services - .rooms - .event_handler - .acl_check(origin, &body.room_id)?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let event = services .rooms .timeline - .get_pdu_json(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; + .get_pdu_json(&body.event_id) + .await + .map_err(|_| Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; let room_id_str = event .get("room_id") @@ -50,16 +43,17 @@ pub(crate) async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; - let auth_chain_ids = services + let auth_chain = services .rooms .auth_chain - .event_ids_iter(room_id, vec![Arc::from(&*body.event_id)]) - .await?; + .event_ids_iter(room_id, &[body.event_id.borrow()]) + .await? + .filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() }) + .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) + .collect() + .await; Ok(get_event_authorization::v1::Response { - auth_chain: auth_chain_ids - .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?) 
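The AccessCheck { .. }.check().await? calls consolidate the ACL, world-readability and server-membership boilerplate each federation route used to repeat inline. The helper's definition is not part of this excerpt, so the following is only the shape implied by its call sites, with a stub Services type:

use ruma::{EventId, RoomId, ServerName};

struct Services; // stub for conduwuit's service registry

struct AccessCheck<'a> {
    services: &'a Services,
    origin: &'a ServerName,
    room_id: &'a RoomId,
    event_id: Option<&'a EventId>,
}

impl AccessCheck<'_> {
    async fn check(&self) -> Result<(), &'static str> {
        // Per the removed per-route code: the origin must pass the room ACL,
        // and the room must be world-readable or the origin must be in the
        // room; a Some(event_id) additionally gates server_can_see_event.
        Ok(())
    }
}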
-			.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
-			.collect(),
+		auth_chain,
 	})
 }
diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs
index e2c3c93cf54cb6cf1bf3efb8c9299de45ff6cbed..7dff44dcce7408edf08b4257de58022315c3f3ae 100644
--- a/src/api/server/get_missing_events.rs
+++ b/src/api/server/get_missing_events.rs
@@ -2,9 +2,10 @@
 use conduit::{Error, Result};
 use ruma::{
 	api::{client::error::ErrorKind, federation::event::get_missing_events},
-	OwnedEventId, RoomId,
+	CanonicalJsonValue, EventId, RoomId,
 };

+use super::AccessCheck;
 use crate::Ruma;

 /// # `POST /_matrix/federation/v1/get_missing_events/{roomId}`
@@ -13,29 +14,16 @@ pub(crate) async fn get_missing_events_route(
 	State(services): State<crate::State>, body: Ruma<get_missing_events::v1::Request>,
 ) -> Result<get_missing_events::v1::Response> {
-	let origin = body.origin.as_ref().expect("server is authenticated");
-
-	services
-		.rooms
-		.event_handler
-		.acl_check(origin, &body.room_id)?;
-
-	if !services
-		.rooms
-		.state_accessor
-		.is_world_readable(&body.room_id)?
-		&& !services
-			.rooms
-			.state_cache
-			.server_in_room(origin, &body.room_id)?
-	{
-		return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room"));
+	AccessCheck {
+		services: &services,
+		origin: body.origin(),
+		room_id: &body.room_id,
+		event_id: None,
 	}
+	.check()
+	.await?;

-	let limit = body
-		.limit
-		.try_into()
-		.expect("UInt could not be converted to usize");
+	let limit = body.limit.try_into()?;

 	let mut queued_events = body.latest_events.clone();
 	// the vec will never have more entries the limit
@@ -43,7 +31,12 @@
 	let mut i: usize = 0;
 	while i < queued_events.len() && events.len() < limit {
-		if let Some(pdu) = services.rooms.timeline.get_pdu_json(&queued_events[i])? {
+		if let Ok(pdu) = services
+			.rooms
+			.timeline
+			.get_pdu_json(&queued_events[i])
+			.await
+		{
 			let room_id_str = pdu
 				.get("room_id")
 				.and_then(|val| val.as_str())
@@ -64,24 +57,32 @@
 			if !services
 				.rooms
 				.state_accessor
-				.server_can_see_event(origin, &body.room_id, &queued_events[i])?
+ .server_can_see_event(body.origin(), &body.room_id, &queued_events[i]) + .await { i = i.saturating_add(1); continue; } - queued_events.extend_from_slice( - &serde_json::from_value::<Vec<OwnedEventId>>( - serde_json::to_value( - pdu.get("prev_events") - .cloned() - .ok_or_else(|| Error::bad_database("Event in db has no prev_events property."))?, - ) - .expect("canonical json is valid json value"), - ) - .map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?, + let prev_events = pdu + .get("prev_events") + .and_then(CanonicalJsonValue::as_array) + .unwrap_or_default(); + + queued_events.extend( + prev_events + .iter() + .map(<&EventId>::try_from) + .filter_map(Result::ok) + .map(ToOwned::to_owned), + ); + + events.push( + services + .sending + .convert_to_outgoing_federation_event(pdu) + .await, ); - events.push(services.sending.convert_to_outgoing_federation_event(pdu)); } i = i.saturating_add(1); } diff --git a/src/api/server/hierarchy.rs b/src/api/server/hierarchy.rs index 530ed1456e4df039e9448e06c48192f064184ce6..e3ce71084e738084711f506a230cb69bc7dc1521 100644 --- a/src/api/server/hierarchy.rs +++ b/src/api/server/hierarchy.rs @@ -10,13 +10,11 @@ pub(crate) async fn get_hierarchy_route( State(services): State<crate::State>, body: Ruma<get_hierarchy::v1::Request>, ) -> Result<get_hierarchy::v1::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - - if services.rooms.metadata.exists(&body.room_id)? { + if services.rooms.metadata.exists(&body.room_id).await { services .rooms .spaces - .get_federation_hierarchy(&body.room_id, origin, body.suggested_only) + .get_federation_hierarchy(&body.room_id, body.origin(), body.suggested_only) .await } else { Err(Error::BadRequest(ErrorKind::NotFound, "Room does not exist.")) diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 688e026c534075a5453c4c281d66a66fbbd59e8e..0ceb914fc91d8de1810519504f2584a1ac9ba061 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -1,11 +1,12 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{utils, warn, Error, PduEvent, Result}; +use base64::{engine::general_purpose, Engine as _}; +use conduit::{err, utils, utils::hash::sha256, warn, Err, Error, PduEvent, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_invite}, events::room::member::{MembershipState, RoomMemberEventContent}, serde::JsonObject, - CanonicalJsonValue, EventId, OwnedUserId, + CanonicalJsonValue, EventId, OwnedUserId, UserId, }; use crate::Ruma; @@ -18,13 +19,12 @@ pub(crate) async fn create_invite_route( State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, body: Ruma<create_invite::v2::Request>, ) -> Result<create_invite::v2::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - // ACL check origin services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(body.origin(), &body.room_id) + .await?; if !services .globals @@ -46,10 +46,7 @@ pub(crate) async fn create_invite_route( .forbidden_remote_server_names .contains(&server.to_owned()) { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } } @@ -57,56 +54,41 @@ pub(crate) async fn create_invite_route( .globals .config .forbidden_remote_server_names - .contains(origin) + .contains(body.origin()) { warn!( - "Received federated/remote invite 
from banned server {origin} for room ID {}. Rejecting.", + "Received federated/remote invite from banned server {} for room ID {}. Rejecting.", + body.origin(), body.room_id ); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server is banned on this homeserver.", - )); - } - if let Some(via) = &body.via { - if via.is_empty() { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "via field must not be empty.")); - } + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } let mut signed_event = utils::to_canonical_object(&body.event) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; - let invited_user: OwnedUserId = serde_json::from_value( - signed_event - .get("state_key") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event has no state_key property."))? - .clone() - .into(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user ID."))?; + let invited_user: OwnedUserId = signed_event + .get("state_key") + .try_into() + .map(UserId::to_owned) + .map_err(|e| err!(Request(InvalidParam("Invalid state_key property: {e}"))))?; if !services.globals.server_is_ours(invited_user.server_name()) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User does not belong to this homeserver.", - )); + return Err!(Request(InvalidParam("User does not belong to this homeserver."))); } // Make sure we're not ACL'ed from their room. services .rooms .event_handler - .acl_check(invited_user.server_name(), &body.room_id)?; + .acl_check(invited_user.server_name(), &body.room_id) + .await?; - ruma::signatures::hash_and_sign_event( - services.globals.server_name().as_str(), - services.globals.keypair(), - &mut signed_event, - &body.room_version, - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; + services + .server_keys + .hash_and_sign_event(&mut signed_event, &body.room_version) + .map_err(|e| err!(Request(InvalidParam("Failed to sign event: {e}"))))?; // Generate event id let event_id = EventId::parse(format!( @@ -119,27 +101,17 @@ pub(crate) async fn create_invite_route( // Add event_id back signed_event.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.to_string())); - let sender: OwnedUserId = serde_json::from_value( - signed_event - .get("sender") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event had no sender property."))? - .clone() - .into(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?; + let sender: &UserId = signed_event + .get("sender") + .try_into() + .map_err(|e| err!(Request(InvalidParam("Invalid sender property: {e}"))))?; - if services.rooms.metadata.is_banned(&body.room_id)? && !services.users.is_admin(&invited_user)? { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This room is banned on this homeserver.", - )); + if services.rooms.metadata.is_banned(&body.room_id).await && !services.users.is_admin(&invited_user).await { + return Err!(Request(Forbidden("This room is banned on this homeserver."))); } - if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user)? 
{ - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "This server does not allow room invites.", - )); + if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user).await { + return Err!(Request(Forbidden("This server does not allow room invites."))); } let mut invite_state = body.invite_room_state.clone(); @@ -154,27 +126,54 @@ pub(crate) async fn create_invite_route( invite_state.push(pdu.to_stripped_state_event()); - // If we are active in the room, the remote server will notify us about the join - // via /send + // If we are active in the room, the remote server will notify us about the + // join/invite through /send. If we are not in the room, we need to manually + // record the invited state for client /sync through update_membership(), and + // send the invite PDU to the relevant appservices. if !services .rooms .state_cache - .server_in_room(services.globals.server_name(), &body.room_id)? + .server_in_room(services.globals.server_name(), &body.room_id) + .await { - services.rooms.state_cache.update_membership( - &body.room_id, - &invited_user, - RoomMemberEventContent::new(MembershipState::Invite), - &sender, - Some(invite_state), - body.via.clone(), - true, - )?; + services + .rooms + .state_cache + .update_membership( + &body.room_id, + &invited_user, + RoomMemberEventContent::new(MembershipState::Invite), + sender, + Some(invite_state), + body.via.clone(), + true, + ) + .await?; + } + + for appservice in services.appservice.read().await.values() { + if appservice.is_user_match(&invited_user) { + services + .sending + .send_appservice_request( + appservice.registration.clone(), + ruma::api::appservice::event::push_events::v1::Request { + events: vec![pdu.to_room_event()], + txn_id: general_purpose::URL_SAFE_NO_PAD + .encode(sha256::hash(pdu.event_id.as_bytes())) + .into(), + ephemeral: Vec::new(), + to_device: Vec::new(), + }, + ) + .await?; + } } Ok(create_invite::v2::Response { event: services .sending - .convert_to_outgoing_federation_event(signed_event), + .convert_to_outgoing_federation_event(signed_event) + .await, }) } diff --git a/src/api/server/key.rs b/src/api/server/key.rs index 686e44242c40a10c2a35080b09e2d2964fc62c9c..37fffa9fbf0ec729d2e6b89c1b93f273aabbe280 100644 --- a/src/api/server/key.rs +++ b/src/api/server/key.rs @@ -1,20 +1,19 @@ use std::{ - collections::BTreeMap, + mem::take, time::{Duration, SystemTime}, }; use axum::{extract::State, response::IntoResponse, Json}; +use conduit::{utils::timepoint_from_now, Result}; use ruma::{ api::{ - federation::discovery::{get_server_keys, ServerSigningKeys, VerifyKey}, + federation::discovery::{get_server_keys, OldVerifyKey, ServerSigningKeys}, OutgoingResponse, }, - serde::{Base64, Raw}, - MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, + serde::Raw, + MilliSecondsSinceUnixEpoch, Signatures, }; -use crate::Result; - /// # `GET /_matrix/key/v2/server` /// /// Gets the public signing keys of this server. 
@@ -24,47 +23,49 @@
 // Response type for this endpoint is Json because we need to calculate a
 // signature for the response
 pub(crate) async fn get_server_keys_route(State(services): State<crate::State>) -> Result<impl IntoResponse> {
-	let verify_keys: BTreeMap<OwnedServerSigningKeyId, VerifyKey> = BTreeMap::from([(
-		format!("ed25519:{}", services.globals.keypair().version())
-			.try_into()
-			.expect("found invalid server signing keys in DB"),
-		VerifyKey {
-			key: Base64::new(services.globals.keypair().public_key().to_vec()),
-		},
-	)]);
+	let server_name = services.globals.server_name();
+	let active_key_id = services.server_keys.active_key_id();
+	let mut all_keys = services.server_keys.verify_keys_for(server_name).await;
+
+	let verify_keys = all_keys
+		.remove_entry(active_key_id)
+		.expect("active verify_key is missing");
+
+	let old_verify_keys = all_keys
+		.into_iter()
+		.map(|(id, key)| (id, OldVerifyKey::new(expires_ts(), key.key)))
+		.collect();
+
+	let server_key = ServerSigningKeys {
+		verify_keys: [verify_keys].into(),
+		old_verify_keys,
+		server_name: server_name.to_owned(),
+		valid_until_ts: valid_until_ts(),
+		signatures: Signatures::new(),
+	};

-	let mut response = serde_json::from_slice(
-		get_server_keys::v2::Response {
-			server_key: Raw::new(&ServerSigningKeys {
-				server_name: services.globals.server_name().to_owned(),
-				verify_keys,
-				old_verify_keys: BTreeMap::new(),
-				signatures: BTreeMap::new(),
-				valid_until_ts: MilliSecondsSinceUnixEpoch::from_system_time(
-					SystemTime::now()
-						.checked_add(Duration::from_secs(86400 * 7))
-						.expect("valid_until_ts should not get this high"),
-				)
-				.expect("time is valid"),
-			})
-			.expect("static conversion, no errors"),
-		}
+	let server_key = Raw::new(&server_key)?;
+	let mut response = get_server_keys::v2::Response::new(server_key)
 		.try_into_http_response::<Vec<u8>>()
-		.unwrap()
-		.body(),
-	)
-	.unwrap();
+		.map(|mut response| take(response.body_mut()))
+		.and_then(|body| serde_json::from_slice(&body).map_err(Into::into))?;

-	ruma::signatures::sign_json(
-		services.globals.server_name().as_str(),
-		services.globals.keypair(),
-		&mut response,
-	)
-	.unwrap();
+	services.server_keys.sign_json(&mut response)?;

 	Ok(Json(response))
 }

+fn valid_until_ts() -> MilliSecondsSinceUnixEpoch {
+	let dur = Duration::from_secs(86400 * 7);
+	let timepoint = timepoint_from_now(dur).expect("SystemTime should not overflow");
+	MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow")
+}
+
+fn expires_ts() -> MilliSecondsSinceUnixEpoch {
+	let timepoint = SystemTime::now();
+	MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow")
+}
+
 /// # `GET /_matrix/key/v2/server/{keyId}`
 ///
 /// Gets the public signing keys of this server.
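Aside on the `/_matrix/key/v2/server` rewrite above: the route now splits the server's verify-key map into the single active key plus `old_verify_keys`, stamping each replaced key as expired immediately (`expires_ts()`) while advertising the response as valid for one week (`86400 * 7` seconds). Below is a minimal, self-contained sketch of that split, using plain stand-in types instead of ruma's `VerifyKey`/`OldVerifyKey`; all names and types here are illustrative only, not the project's API.

use std::collections::BTreeMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

// Illustrative stand-ins for ruma's VerifyKey / OldVerifyKey.
#[derive(Debug)]
struct VerifyKey {
	key: String,
}

#[derive(Debug)]
struct OldVerifyKey {
	expired_ts: u64,
	key: String,
}

fn now_millis() -> u64 {
	SystemTime::now()
		.duration_since(UNIX_EPOCH)
		.expect("system clock is before the unix epoch")
		.as_millis() as u64
}

/// Pull the active key out of the full key map and mark every other key as
/// expired "now" -- mirroring remove_entry(active_key_id) followed by the
/// OldVerifyKey::new(expires_ts(), ..) mapping in the route above.
fn split_keys(
	mut all_keys: BTreeMap<String, VerifyKey>,
	active_key_id: &str,
) -> ((String, VerifyKey), BTreeMap<String, OldVerifyKey>) {
	let active = all_keys
		.remove_entry(active_key_id)
		.expect("active verify_key is missing");

	let old = all_keys
		.into_iter()
		.map(|(id, k)| {
			(
				id,
				OldVerifyKey {
					expired_ts: now_millis(),
					key: k.key,
				},
			)
		})
		.collect();

	(active, old)
}

fn main() {
	let mut keys = BTreeMap::new();
	keys.insert("ed25519:old".to_owned(), VerifyKey { key: "old-public-key".into() });
	keys.insert("ed25519:active".to_owned(), VerifyKey { key: "active-public-key".into() });

	let (active, old) = split_keys(keys, "ed25519:active");

	// valid_until_ts: advertised one week out, matching 86400 * 7 seconds.
	let valid_until_ts = now_millis().saturating_add(Duration::from_secs(86400 * 7).as_millis() as u64);

	println!("active: {active:?}");
	println!("old: {old:?}");
	println!("valid_until_ts: {valid_until_ts}");
}

Expiring superseded keys at "now" is the conservative choice: verifiers can still fetch them to validate old signatures, but nothing new should be accepted as signed by them.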
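Several routes in this section (`event`, `event_auth`, `get_missing_events`) replace their hand-rolled origin checks with the shared `AccessCheck` added via `src/api/server/utils.rs`, a file this diff does not show. Reconstructed from the inline checks it replaces, a simplified, hypothetical model of the policy looks roughly like the sketch below; the boolean fields and synchronous shape are assumptions made for illustration (the real struct borrows `services` and is async).

/// Hypothetical stand-in for the AccessCheck used above, reduced to plain
/// booleans. Derived from the checks it replaces: the origin must pass the
/// room ACL, must be in the room unless the room is world-readable, and may
/// additionally be gated on per-event history visibility.
struct AccessCheck<'a> {
	origin: &'a str,
	acl_allows_origin: bool,
	room_is_world_readable: bool,
	origin_in_room: bool,
	// None means no event-specific visibility check was requested.
	origin_can_see_event: Option<bool>,
}

impl AccessCheck<'_> {
	fn check(&self) -> Result<(), String> {
		// ACL check first, mirroring the old per-route acl_check() calls.
		if !self.acl_allows_origin {
			return Err(format!("{} is ACL'd from this room", self.origin));
		}
		// World-readable rooms are visible to any server; otherwise the
		// origin must participate in the room ("Server is not in room.").
		if !self.room_is_world_readable && !self.origin_in_room {
			return Err("Server is not in room.".to_owned());
		}
		// Optional history-visibility gate ("Server is not allowed to see event.").
		if self.origin_can_see_event == Some(false) {
			return Err("Server is not allowed to see event.".to_owned());
		}
		Ok(())
	}
}

fn main() {
	let check = AccessCheck {
		origin: "example.org",
		acl_allows_origin: true,
		room_is_world_readable: false,
		origin_in_room: true,
		origin_can_see_event: Some(true),
	};
	assert!(check.check().is_ok());
	println!("access granted to {}", check.origin);
}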
diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index 021016be22bfe38683cb34f350f00b224ec29141..d5ea675e98146305ed3623ae250c2bb88297809f 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -1,4 +1,9 @@ use axum::extract::State; +use conduit::{ + utils::{IterStream, ReadyExt}, + warn, +}; +use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_join_event}, events::{ @@ -6,12 +11,11 @@ join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent}, member::{MembershipState, RoomMemberEventContent}, }, - StateEventType, TimelineEventType, + StateEventType, }, CanonicalJsonObject, RoomId, RoomVersionId, UserId, }; use serde_json::value::to_raw_value; -use tracing::warn; use crate::{ service::{pdu::PduBuilder, Services}, @@ -24,12 +28,11 @@ pub(crate) async fn create_join_event_template_route( State(services): State<crate::State>, body: Ruma<prepare_join_event::v1::Request>, ) -> Result<prepare_join_event::v1::Response> { - if !services.rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } - let origin = body.origin.as_ref().expect("server is authenticated"); - if body.user_id.server_name() != origin { + if body.user_id.server_name() != body.origin() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Not allowed to join on behalf of another server/user", @@ -40,18 +43,21 @@ pub(crate) async fn create_join_event_template_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(body.origin(), &body.room_id) + .await?; if services .globals .config .forbidden_remote_server_names - .contains(origin) + .contains(body.origin()) { warn!( - "Server {origin} for remote user {} tried joining room ID {} which has a server name that is globally \ + "Server {} for remote user {} tried joining room ID {} which has a server name that is globally \ forbidden. Rejecting.", - &body.user_id, &body.room_id, + body.origin(), + &body.user_id, + &body.room_id, ); return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -73,7 +79,15 @@ pub(crate) async fn create_join_event_template_route( } } - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; + if !body.ver.contains(&room_version_id) { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: room_version_id, + }, + "Room version not supported.", + )); + } let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; @@ -81,22 +95,24 @@ pub(crate) async fn create_join_event_template_route( .rooms .state_cache .is_left(&body.user_id, &body.room_id) - .unwrap_or(true)) - && user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id)? + .await) + && user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id).await? 
{ let auth_user = services .rooms .state_cache .room_members(&body.room_id) - .filter_map(Result::ok) - .filter(|user| user.server_name() == services.globals.server_name()) - .find(|user| { + .ready_filter(|user| user.server_name() == services.globals.server_name()) + .filter(|user| { services .rooms .state_accessor .user_can_invite(&body.room_id, user, &body.user_id, &state_lock) - .unwrap_or(false) - }); + }) + .boxed() + .next() + .await + .map(ToOwned::to_owned); if auth_user.is_some() { auth_user @@ -110,41 +126,22 @@ pub(crate) async fn create_join_event_template_route( None }; - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; - if !body.ver.contains(&room_version_id) { - return Err(Error::BadRequest( - ErrorKind::IncompatibleRoomVersion { - room_version: room_version_id, - }, - "Room version not supported.", - )); - } - - let content = to_raw_value(&RoomMemberEventContent { - avatar_url: None, - blurhash: None, - displayname: None, - is_direct: None, - membership: MembershipState::Join, - third_party_invite: None, - reason: None, - join_authorized_via_users_server, - }) - .expect("member event is valid value"); - - let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder::state( + body.user_id.to_string(), + &RoomMemberEventContent { + join_authorized_via_users_server, + ..RoomMemberEventContent::new(MembershipState::Join) + }, + ), + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; drop(state_lock); @@ -161,7 +158,7 @@ pub(crate) async fn create_join_event_template_route( /// This doesn't check the current user's membership. This should be done /// externally, either by using the state cache or attempting to authorize the /// event. -pub(crate) fn user_can_perform_restricted_join( +pub(crate) async fn user_can_perform_restricted_join( services: &Services, user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId, ) -> Result<bool> { use RoomVersionId::*; @@ -169,18 +166,15 @@ pub(crate) fn user_can_perform_restricted_join( let join_rules_event = services .rooms .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; - - let Some(join_rules_event_content) = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str::<RoomJoinRulesEventContent>(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event in database: {e}"); - Error::bad_database("Invalid join rules event in database") - }) + .room_state_get(room_id, &StateEventType::RoomJoinRules, "") + .await; + + let Ok(Ok(join_rules_event_content)) = join_rules_event.as_ref().map(|join_rules_event| { + serde_json::from_str::<RoomJoinRulesEventContent>(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event in database: {e}"); + Error::bad_database("Invalid join rules event in database") }) - .transpose()? 
- else { + }) else { return Ok(false); }; @@ -201,13 +195,10 @@ pub(crate) fn user_can_perform_restricted_join( None } }) - .any(|m| { - services - .rooms - .state_cache - .is_joined(user_id, &m.room_id) - .unwrap_or(false) - }) { + .stream() + .any(|m| services.rooms.state_cache.is_joined(user_id, &m.room_id)) + .await + { Ok(true) } else { Err(Error::BadRequest( diff --git a/src/api/server/make_knock.rs b/src/api/server/make_knock.rs new file mode 100644 index 0000000000000000000000000000000000000000..c1875a1f81683f0e3f28de85cfe4f045fdcb8d04 --- /dev/null +++ b/src/api/server/make_knock.rs @@ -0,0 +1,107 @@ +use axum::extract::State; +use conduit::Err; +use ruma::{ + api::{client::error::ErrorKind, federation::knock::create_knock_event_template}, + events::room::member::{MembershipState, RoomMemberEventContent}, + RoomVersionId, +}; +use serde_json::value::to_raw_value; +use tracing::warn; +use RoomVersionId::*; + +use crate::{service::pdu::PduBuilder, Error, Result, Ruma}; + +/// # `GET /_matrix/federation/v1/make_knock/{roomId}/{userId}` +/// +/// Creates a knock template. +pub(crate) async fn create_knock_event_template_route( + State(services): State<crate::State>, body: Ruma<create_knock_event_template::v1::Request>, +) -> Result<create_knock_event_template::v1::Response> { + if !services.rooms.metadata.exists(&body.room_id).await { + return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); + } + + if body.user_id.server_name() != body.origin() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Not allowed to knock on behalf of another server/user", + )); + } + + // ACL check origin server + services + .rooms + .event_handler + .acl_check(body.origin(), &body.room_id) + .await?; + + if services + .globals + .config + .forbidden_remote_server_names + .contains(body.origin()) + { + warn!( + "Server {} for remote user {} tried knocking room ID {} which has a server name that is globally \ + forbidden. 
Rejecting.", + body.origin(), + &body.user_id, + &body.room_id, + ); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); + } + + if let Some(server) = body.room_id.server_name() { + if services + .globals + .config + .forbidden_remote_server_names + .contains(&server.to_owned()) + { + return Err!(Request(Forbidden("Server is banned on this homeserver."))); + } + } + + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; + + if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6) { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: room_version_id, + }, + "Room version does not support knocking.", + )); + } + + if !body.ver.contains(&room_version_id) { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: room_version_id, + }, + "Your homeserver does not support the features required to knock on this room.", + )); + } + + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; + + let (_pdu, mut pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder::state(body.user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Knock)), + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; + + drop(state_lock); + + // room v3 and above removed the "event_id" field from remote PDU format + super::maybe_strip_event_id(&mut pdu_json, &room_version_id)?; + + Ok(create_knock_event_template::v1::Response { + room_version: room_version_id, + event: to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"), + }) +} diff --git a/src/api/server/make_leave.rs b/src/api/server/make_leave.rs index 3eb0d77ab7ce59a204e84824bfa78bca5e82da79..33a945603f3e4cbd970b5211664cbed23d6d7900 100644 --- a/src/api/server/make_leave.rs +++ b/src/api/server/make_leave.rs @@ -2,10 +2,7 @@ use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_leave_event}, - events::{ - room::member::{MembershipState, RoomMemberEventContent}, - TimelineEventType, - }, + events::room::member::{MembershipState, RoomMemberEventContent}, }; use serde_json::value::to_raw_value; @@ -18,12 +15,11 @@ pub(crate) async fn create_leave_event_template_route( State(services): State<crate::State>, body: Ruma<prepare_leave_event::v1::Request>, ) -> Result<prepare_leave_event::v1::Response> { - if !services.rooms.metadata.exists(&body.room_id)? 
{ + if !services.rooms.metadata.exists(&body.room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } - let origin = body.origin.as_ref().expect("server is authenticated"); - if body.user_id.server_name() != origin { + if body.user_id.server_name() != body.origin() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Not allowed to leave on behalf of another server/user", @@ -34,35 +30,22 @@ pub(crate) async fn create_leave_event_template_route( services .rooms .event_handler - .acl_check(origin, &body.room_id)?; + .acl_check(body.origin(), &body.room_id) + .await?; - let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - let content = to_raw_value(&RoomMemberEventContent { - avatar_url: None, - blurhash: None, - displayname: None, - is_direct: None, - membership: MembershipState::Leave, - third_party_invite: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - timestamp: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services + .rooms + .timeline + .create_hash_and_sign_event( + PduBuilder::state(body.user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Leave)), + &body.user_id, + &body.room_id, + &state_lock, + ) + .await?; drop(state_lock); diff --git a/src/api/server/mod.rs b/src/api/server/mod.rs index 9a184f237b2b3b6c3401d5e825df6b7eb036dcf0..9b7d91cba688e40063feb6ccf4925765998440d3 100644 --- a/src/api/server/mod.rs +++ b/src/api/server/mod.rs @@ -41,3 +41,6 @@ pub(super) use user::*; pub(super) use version::*; pub(super) use well_known::*; + +mod utils; +use utils::AccessCheck; diff --git a/src/api/server/openid.rs b/src/api/server/openid.rs index 6a1b99b755e3429910444613722aa7f4c057d3c8..9b54807a691374a6281b7be1346c17ad881a9326 100644 --- a/src/api/server/openid.rs +++ b/src/api/server/openid.rs @@ -10,6 +10,9 @@ pub(crate) async fn get_openid_userinfo_route( State(services): State<crate::State>, body: Ruma<get_openid_userinfo::v1::Request>, ) -> Result<get_openid_userinfo::v1::Response> { Ok(get_openid_userinfo::v1::Response::new( - services.users.find_from_openid_token(&body.access_token)?, + services + .users + .find_from_openid_token(&body.access_token) + .await?, )) } diff --git a/src/api/server/publicrooms.rs b/src/api/server/publicrooms.rs index af8a58464d0c6463640c8d35ba652ce0ae070bae..f6c418592a0a60cb30290d70fb862a7bd48a8d48 100644 --- a/src/api/server/publicrooms.rs +++ b/src/api/server/publicrooms.rs @@ -20,7 +20,8 @@ pub(crate) async fn get_public_rooms_filtered_route( ) -> Result<get_public_rooms_filtered::v1::Response> { if !services .globals - .allow_public_room_directory_over_federation() + .config + .allow_public_room_directory_over_federation { return Err(Error::BadRequest(ErrorKind::forbidden(), "Room directory is not public")); } diff --git a/src/api/server/query.rs b/src/api/server/query.rs index c2b78bded157bcee9e54725aabe666d4c35f4898..bf515b3c7f9661400b8e8bed3c424f27572646fe 100644 --- a/src/api/server/query.rs +++ b/src/api/server/query.rs @@ -1,7 +1,8 @@ use 
std::collections::BTreeMap; use axum::extract::State; -use conduit::{Error, Result}; +use conduit::{err, Error, Result}; +use futures::StreamExt; use get_profile_information::v1::ProfileField; use rand::seq::SliceRandom; use ruma::{ @@ -23,15 +24,17 @@ pub(crate) async fn get_room_information_route( let room_id = services .rooms .alias - .resolve_local_alias(&body.room_alias)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?; + .resolve_local_alias(&body.room_alias) + .await + .map_err(|_| err!(Request(NotFound("Room alias not found."))))?; let mut servers: Vec<OwnedServerName> = services .rooms .state_cache .room_servers(&room_id) - .filter_map(Result::ok) - .collect(); + .map(ToOwned::to_owned) + .collect() + .await; servers.sort_unstable(); servers.dedup(); @@ -60,7 +63,11 @@ pub(crate) async fn get_room_information_route( pub(crate) async fn get_profile_information_route( State(services): State<crate::State>, body: Ruma<get_profile_information::v1::Request>, ) -> Result<get_profile_information::v1::Response> { - if !services.globals.allow_profile_lookup_federation_requests() { + if !services + .globals + .config + .allow_inbound_profile_lookup_federation_requests + { return Err(Error::BadRequest( ErrorKind::forbidden(), "Profile lookup over federation is not allowed on this homeserver.", @@ -82,30 +89,31 @@ pub(crate) async fn get_profile_information_route( match &body.field { Some(ProfileField::DisplayName) => { - displayname = services.users.displayname(&body.user_id)?; + displayname = services.users.displayname(&body.user_id).await.ok(); }, Some(ProfileField::AvatarUrl) => { - avatar_url = services.users.avatar_url(&body.user_id)?; - blurhash = services.users.blurhash(&body.user_id)?; + avatar_url = services.users.avatar_url(&body.user_id).await.ok(); + blurhash = services.users.blurhash(&body.user_id).await.ok(); }, Some(custom_field) => { - if let Some(value) = services + if let Ok(value) = services .users - .profile_key(&body.user_id, custom_field.as_str())? 
+ .profile_key(&body.user_id, custom_field.as_str()) + .await { custom_profile_fields.insert(custom_field.to_string(), value); } }, None => { - displayname = services.users.displayname(&body.user_id)?; - avatar_url = services.users.avatar_url(&body.user_id)?; - blurhash = services.users.blurhash(&body.user_id)?; - tz = services.users.timezone(&body.user_id)?; + displayname = services.users.displayname(&body.user_id).await.ok(); + avatar_url = services.users.avatar_url(&body.user_id).await.ok(); + blurhash = services.users.blurhash(&body.user_id).await.ok(); + tz = services.users.timezone(&body.user_id).await.ok(); custom_profile_fields = services .users .all_profile_keys(&body.user_id) - .filter_map(Result::ok) - .collect(); + .collect() + .await; }, } diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 15f82faa721b91d162d56ec69c9e89047110083a..2da99c93644187efcef28b993e2ecdc96d68196a 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -2,7 +2,8 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{debug, debug_warn, err, trace, warn, Err}; +use conduit::{debug, debug_warn, err, error, result::LogErr, trace, utils::ReadyExt, warn, Err, Error, Result}; +use futures::StreamExt; use ruma::{ api::{ client::error::ErrorKind, @@ -15,18 +16,22 @@ }, }, events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType}, + serde::Raw, to_device::DeviceIdOrAllDevices, OwnedEventId, ServerName, }; -use tokio::sync::RwLock; +use serde_json::value::RawValue as RawJsonValue; +use service::{ + sending::{EDU_LIMIT, PDU_LIMIT}, + Services, +}; use crate::{ - services::Services, utils::{self}, - Error, Result, Ruma, + Ruma, }; -type ResolvedMap = BTreeMap<OwnedEventId, Result<(), Error>>; +type ResolvedMap = BTreeMap<OwnedEventId, Result<()>>; /// # `PUT /_matrix/federation/v1/send/{txnId}` /// @@ -36,20 +41,22 @@ pub(crate) async fn send_transaction_message_route( State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, body: Ruma<send_transaction_message::v1::Request>, ) -> Result<send_transaction_message::v1::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - - if *origin != body.body.origin { + if body.origin() != body.body.origin { return Err!(Request(Forbidden( "Not allowed to send transactions on behalf of other servers" ))); } - if body.pdus.len() > 50_usize { - return Err!(Request(Forbidden("Not allowed to send more than 50 PDUs in one transaction"))); + if body.pdus.len() > PDU_LIMIT { + return Err!(Request(Forbidden( + "Not allowed to send more than {PDU_LIMIT} PDUs in one transaction" + ))); } - if body.edus.len() > 100_usize { - return Err!(Request(Forbidden("Not allowed to send more than 100 EDUs in one transaction"))); + if body.edus.len() > EDU_LIMIT { + return Err!(Request(Forbidden( + "Not allowed to send more than {EDU_LIMIT} EDUs in one transaction" + ))); } let txn_start_time = Instant::now(); @@ -58,37 +65,36 @@ pub(crate) async fn send_transaction_message_route( edus = ?body.edus.len(), elapsed = ?txn_start_time.elapsed(), id = ?body.transaction_id, - origin =?body.origin, + origin =?body.origin(), "Starting txn", ); - let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await?; - handle_edus(&services, &client, &body, origin).await?; + let resolved_map = handle_pdus(&services, &client, &body.pdus, body.origin(), &txn_start_time).await?; + handle_edus(&services, &client, &body.edus, body.origin()).await; debug!( pdus = ?body.pdus.len(), 
edus = ?body.edus.len(), elapsed = ?txn_start_time.elapsed(), id = ?body.transaction_id, - origin =?body.origin, + origin =?body.origin(), "Finished txn", ); Ok(send_transaction_message::v1::Response { pdus: resolved_map .into_iter() - .map(|(e, r)| (e, r.map_err(|e| e.sanitized_string()))) + .map(|(e, r)| (e, r.map_err(error::sanitized_message))) .collect(), }) } async fn handle_pdus( - services: &Services, _client: &IpAddr, body: &Ruma<send_transaction_message::v1::Request>, origin: &ServerName, - txn_start_time: &Instant, + services: &Services, _client: &IpAddr, pdus: &[Box<RawJsonValue>], origin: &ServerName, txn_start_time: &Instant, ) -> Result<ResolvedMap> { - let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); - for pdu in &body.pdus { - parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu) { + let mut parsed_pdus = Vec::with_capacity(pdus.len()); + for pdu in pdus { + parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await { Ok(t) => t, Err(e) => { debug_warn!("Could not parse PDU: {e}"); @@ -100,24 +106,9 @@ async fn handle_pdus( // and hashes checks } - // We go through all the signatures we see on the PDUs and fetch the - // corresponding signing keys - let pub_key_map = RwLock::new(BTreeMap::new()); - if !parsed_pdus.is_empty() { - services - .server_keys - .fetch_required_signing_keys(parsed_pdus.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map) - .await - .unwrap_or_else(|e| warn!("Could not fetch all signatures for PDUs from {origin}: {e:?}")); - - debug!( - elapsed = ?txn_start_time.elapsed(), - "Fetched signing keys" - ); - } - let mut resolved_map = BTreeMap::new(); for (event_id, value, room_id) in parsed_pdus { + services.server.check_running()?; let pdu_start_time = Instant::now(); let mutex_lock = services .rooms @@ -125,28 +116,28 @@ async fn handle_pdus( .mutex_federation .lock(&room_id) .await; - resolved_map.insert( - event_id.clone(), - services - .rooms - .event_handler - .handle_incoming_pdu(origin, &room_id, &event_id, value, true, &pub_key_map) - .await - .map(|_| ()), - ); - drop(mutex_lock); + let result = services + .rooms + .event_handler + .handle_incoming_pdu(origin, &room_id, &event_id, value, true) + .await + .map(|_| ()); + + drop(mutex_lock); debug!( pdu_elapsed = ?pdu_start_time.elapsed(), txn_elapsed = ?txn_start_time.elapsed(), "Finished PDU {event_id}", ); + + resolved_map.insert(event_id, result); } - for pdu in &resolved_map { - if let Err(e) = pdu.1 { + for (id, result) in &resolved_map { + if let Err(e) = result { if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) { - warn!("Incoming PDU failed {pdu:?}"); + warn!("Incoming PDU failed {id}: {e:?}"); } } } @@ -154,35 +145,28 @@ async fn handle_pdus( Ok(resolved_map) } -async fn handle_edus( - services: &Services, client: &IpAddr, body: &Ruma<send_transaction_message::v1::Request>, origin: &ServerName, -) -> Result<()> { - for edu in body - .edus +async fn handle_edus(services: &Services, client: &IpAddr, edus: &[Raw<Edu>], origin: &ServerName) { + for edu in edus .iter() .filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok()) { match edu { - Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await?, - Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await?, - Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await?, - Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, 
content).await?, - Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await?, - Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await?, + Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await, + Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await, + Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await, + Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await, + Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await, + Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await, Edu::_Custom(ref _custom) => { - debug_warn!(?body.edus, "received custom/unknown EDU"); + debug_warn!(?edus, "received custom/unknown EDU"); }, } } - - Ok(()) } -async fn handle_edu_presence( - services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent, -) -> Result<()> { +async fn handle_edu_presence(services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent) { if !services.globals.allow_incoming_presence() { - return Ok(()); + return; } for update in presence.push { @@ -194,23 +178,24 @@ async fn handle_edu_presence( continue; } - services.presence.set_presence( - &update.user_id, - &update.presence, - Some(update.currently_active), - Some(update.last_active_ago), - update.status_msg.clone(), - )?; + services + .presence + .set_presence( + &update.user_id, + &update.presence, + Some(update.currently_active), + Some(update.last_active_ago), + update.status_msg.clone(), + ) + .await + .log_err() + .ok(); } - - Ok(()) } -async fn handle_edu_receipt( - services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent, -) -> Result<()> { +async fn handle_edu_receipt(services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent) { if !services.globals.allow_incoming_read_receipts() { - return Ok(()); + return; } for (room_id, room_updates) in receipt.receipts { @@ -218,6 +203,7 @@ async fn handle_edu_receipt( .rooms .event_handler .acl_check(origin, &room_id) + .await .is_err() { debug_warn!( @@ -240,8 +226,8 @@ async fn handle_edu_receipt( .rooms .state_cache .room_members(&room_id) - .filter_map(Result::ok) - .any(|member| member.server_name() == user_id.server_name()) + .ready_any(|member| member.server_name() == user_id.server_name()) + .await { for event_id in &user_updates.event_ids { let user_receipts = BTreeMap::from([(user_id.clone(), user_updates.data.clone())]); @@ -255,7 +241,8 @@ async fn handle_edu_receipt( services .rooms .read_receipt - .readreceipt_update(&user_id, &room_id, &event)?; + .readreceipt_update(&user_id, &room_id, &event) + .await; } } else { debug_warn!( @@ -266,15 +253,11 @@ async fn handle_edu_receipt( } } } - - Ok(()) } -async fn handle_edu_typing( - services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent, -) -> Result<()> { +async fn handle_edu_typing(services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent) { if !services.globals.config.allow_incoming_typing { - return Ok(()); + return; } if typing.user_id.server_name() != origin { @@ -282,26 +265,28 @@ async fn handle_edu_typing( %typing.user_id, %origin, "received typing EDU for user not belonging to origin" ); - return Ok(()); + return; } if services .rooms 
.event_handler .acl_check(typing.user_id.server_name(), &typing.room_id) + .await .is_err() { debug_warn!( %typing.user_id, %typing.room_id, %origin, "received typing EDU for ACL'd user's server" ); - return Ok(()); + return; } if services .rooms .state_cache - .is_joined(&typing.user_id, &typing.room_id)? + .is_joined(&typing.user_id, &typing.room_id) + .await { if typing.typing { let timeout = utils::millis_since_unix_epoch().saturating_add( @@ -315,28 +300,29 @@ async fn handle_edu_typing( .rooms .typing .typing_add(&typing.user_id, &typing.room_id, timeout) - .await?; + .await + .log_err() + .ok(); } else { services .rooms .typing .typing_remove(&typing.user_id, &typing.room_id) - .await?; + .await + .log_err() + .ok(); } } else { debug_warn!( %typing.user_id, %typing.room_id, %origin, "received typing EDU for user not in room" ); - return Ok(()); } - - Ok(()) } async fn handle_edu_device_list_update( services: &Services, _client: &IpAddr, origin: &ServerName, content: DeviceListUpdateContent, -) -> Result<()> { +) { let DeviceListUpdateContent { user_id, .. @@ -347,17 +333,15 @@ async fn handle_edu_device_list_update( %user_id, %origin, "received device list update EDU for user not belonging to origin" ); - return Ok(()); + return; } - services.users.mark_device_key_update(&user_id)?; - - Ok(()) + services.users.mark_device_key_update(&user_id).await; } async fn handle_edu_direct_to_device( services: &Services, _client: &IpAddr, origin: &ServerName, content: DirectDeviceContent, -) -> Result<()> { +) { let DirectDeviceContent { sender, ev_type, @@ -370,45 +354,52 @@ async fn handle_edu_direct_to_device( %sender, %origin, "received direct to device EDU for user not belonging to origin" ); - return Ok(()); + return; } // Check if this is a new transaction id if services .transaction_ids - .existing_txnid(&sender, None, &message_id)? 
- .is_some() + .existing_txnid(&sender, None, &message_id) + .await + .is_ok() { - return Ok(()); + return; } for (target_user_id, map) in &messages { for (target_device_id_maybe, event) in map { + let Ok(event) = event + .deserialize_as() + .map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}"))))) + else { + continue; + }; + + let ev_type = ev_type.to_string(); match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services.users.add_to_device_event( - &sender, - target_user_id, - target_device_id, - &ev_type.to_string(), - event - .deserialize_as() - .map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))?, - )?; + services + .users + .add_to_device_event(&sender, target_user_id, target_device_id, &ev_type, event) + .await; }, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services.users.all_device_ids(target_user_id) { - services.users.add_to_device_event( - &sender, - target_user_id, - &target_device_id?, - &ev_type.to_string(), - event - .deserialize_as() - .map_err(|e| err!(Request(InvalidParam("Event is invalid: {e}"))))?, - )?; - } + let (sender, ev_type, event) = (&sender, &ev_type, &event); + services + .users + .all_device_ids(target_user_id) + .for_each(|target_device_id| { + services.users.add_to_device_event( + sender, + target_user_id, + target_device_id, + ev_type, + event.clone(), + ) + }) + .await; }, } } @@ -417,14 +408,12 @@ async fn handle_edu_direct_to_device( // Save transaction id with empty data services .transaction_ids - .add_txnid(&sender, None, &message_id, &[])?; - - Ok(()) + .add_txnid(&sender, None, &message_id, &[]); } async fn handle_edu_signing_key_update( services: &Services, _client: &IpAddr, origin: &ServerName, content: SigningKeyUpdateContent, -) -> Result<()> { +) { let SigningKeyUpdateContent { user_id, master_key, @@ -436,14 +425,15 @@ async fn handle_edu_signing_key_update( %user_id, %origin, "received signing key update EDU from server that does not belong to user's server" ); - return Ok(()); + return; } if let Some(master_key) = master_key { services .users - .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; + .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true) + .await + .log_err() + .ok(); } - - Ok(()) } diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index c4d016f61c84f02e64a67d7d4cc0913c47a39096..60ec8c1f48947f79b43604e2f20f0d1154555e4c 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -1,20 +1,20 @@ #![allow(deprecated)] -use std::collections::BTreeMap; +use std::borrow::Borrow; use axum::extract::State; -use conduit::{pdu::gen_event_id_canonical_json, warn, Error, Result}; +use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_join_event}, events::{ room::member::{MembershipState, RoomMemberEventContent}, StateEventType, }, - CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, + CanonicalJsonValue, EventId, OwnedServerName, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use service::Services; -use tokio::sync::RwLock; use crate::Ruma; @@ -22,27 +22,29 @@ async fn create_join_event( services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> 
Result<create_join_event::v1::RoomState> { - if !services.rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } // ACL check origin server - services.rooms.event_handler.acl_check(origin, room_id)?; + services + .rooms + .event_handler + .acl_check(origin, room_id) + .await?; // We need to return the state prior to joining, let's keep a reference to that // here let shortstatehash = services .rooms .state - .get_room_shortstatehash(room_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Event state not found."))?; - - let pub_key_map = RwLock::new(BTreeMap::new()); - // let mut auth_cache = EventMap::new(); + .get_room_shortstatehash(room_id) + .await + .map_err(|_| err!(Request(NotFound("Event state not found."))))?; // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json @@ -97,7 +99,8 @@ async fn create_join_event( services .rooms .event_handler - .acl_check(sender.server_name(), room_id)?; + .acl_check(sender.server_name(), room_id) + .await?; // check if origin server is trying to send for another server if sender.server_name() != origin { @@ -126,22 +129,16 @@ async fn create_join_event( if content .join_authorized_via_users_server .is_some_and(|user| services.globals.user_is_local(&user)) - && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id).unwrap_or_default() + && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id) + .await + .unwrap_or_default() { - ruma::signatures::hash_and_sign_event( - services.globals.server_name().as_str(), - services.globals.keypair(), - &mut value, - &room_version_id, - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; + services + .server_keys + .hash_and_sign_event(&mut value, &room_version_id) + .map_err(|e| err!(Request(InvalidParam("Failed to sign event: {e}"))))?; } - services - .server_keys - .fetch_required_signing_keys([&value], &pub_key_map) - .await?; - let origin: OwnedServerName = serde_json::from_value( serde_json::to_value( value @@ -158,12 +155,14 @@ async fn create_join_event( .mutex_federation .lock(room_id) .await; - let pdu_id: Vec<u8> = services + + let pdu_id = services .rooms .event_handler - .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map) + .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true) .await? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; + .ok_or_else(|| err!(Request(InvalidParam("Could not accept as timeline event."))))?; + drop(mutex_lock); let state_ids = services @@ -171,29 +170,44 @@ async fn create_join_event( .state_accessor .state_full_ids(shortstatehash) .await?; - let auth_chain_ids = services + + let state = state_ids + .iter() + .try_stream() + .and_then(|(_, event_id)| services.rooms.timeline.get_pdu_json(event_id)) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() + .await?; + + let starting_events: Vec<&EventId> = state_ids.values().map(Borrow::borrow).collect(); + let auth_chain = services .rooms .auth_chain - .event_ids_iter(room_id, state_ids.values().cloned().collect()) + .event_ids_iter(room_id, &starting_events) + .await? + .map(Ok) + .and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await }) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() .await?; - services.sending.send_pdu_room(room_id, &pdu_id)?; + services.sending.send_pdu_room(room_id, &pdu_id).await?; Ok(create_join_event::v1::RoomState { - auth_chain: auth_chain_ids - .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), - state: state_ids - .iter() - .filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten()) - .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - .collect(), + auth_chain, + state, // Event field is required if the room version supports restricted join rules. - event: Some( - to_raw_value(&CanonicalJsonValue::Object(value)) - .expect("To raw json should not fail since only change was adding signature"), - ), + event: to_raw_value(&CanonicalJsonValue::Object(value)).ok(), }) } @@ -203,16 +217,15 @@ async fn create_join_event( pub(crate) async fn create_join_event_v1_route( State(services): State<crate::State>, body: Ruma<create_join_event::v1::Request>, ) -> Result<create_join_event::v1::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - if services .globals .config .forbidden_remote_server_names - .contains(origin) + .contains(body.origin()) { warn!( - "Server {origin} tried joining room ID {} who has a server name that is globally forbidden. Rejecting.", + "Server {} tried joining room ID {} who has a server name that is globally forbidden. Rejecting.", + body.origin(), &body.room_id, ); return Err(Error::BadRequest( @@ -229,8 +242,8 @@ pub(crate) async fn create_join_event_v1_route( .contains(&server.to_owned()) { warn!( - "Server {origin} tried joining room ID {} which has a server name that is globally forbidden. \ - Rejecting.", + "Server {} tried joining room ID {} which has a server name that is globally forbidden. 
Rejecting.", + body.origin(), &body.room_id, ); return Err(Error::BadRequest( @@ -240,7 +253,9 @@ pub(crate) async fn create_join_event_v1_route( } } - let room_state = create_join_event(&services, origin, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(&services, body.origin(), &body.room_id, &body.pdu) + .boxed() + .await?; Ok(create_join_event::v1::Response { room_state, @@ -253,13 +268,11 @@ pub(crate) async fn create_join_event_v1_route( pub(crate) async fn create_join_event_v2_route( State(services): State<crate::State>, body: Ruma<create_join_event::v2::Request>, ) -> Result<create_join_event::v2::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - if services .globals .config .forbidden_remote_server_names - .contains(origin) + .contains(body.origin()) { return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -285,7 +298,9 @@ pub(crate) async fn create_join_event_v2_route( auth_chain, state, event, - } = create_join_event(&services, origin, &body.room_id, &body.pdu).await?; + } = create_join_event(&services, body.origin(), &body.room_id, &body.pdu) + .boxed() + .await?; let room_state = create_join_event::v2::RoomState { members_omitted: false, auth_chain, diff --git a/src/api/server/send_knock.rs b/src/api/server/send_knock.rs new file mode 100644 index 0000000000000000000000000000000000000000..c57998aec5d46a9742d97fe396884af522bc4540 --- /dev/null +++ b/src/api/server/send_knock.rs @@ -0,0 +1,190 @@ +use axum::extract::State; +use conduit::{err, pdu::gen_event_id_canonical_json, warn, Err, Error, PduEvent, Result}; +use ruma::{ + api::{client::error::ErrorKind, federation::knock::send_knock}, + events::{ + room::member::{MembershipState, RoomMemberEventContent}, + StateEventType, + }, + serde::JsonObject, + OwnedServerName, OwnedUserId, + RoomVersionId::*, +}; + +use crate::Ruma; + +/// # `PUT /_matrix/federation/v1/send_knock/{roomId}/{eventId}` +/// +/// Submits a signed knock event. +pub(crate) async fn create_knock_event_v1_route( + State(services): State<crate::State>, body: Ruma<send_knock::v1::Request>, +) -> Result<send_knock::v1::Response> { + if services + .globals + .config + .forbidden_remote_server_names + .contains(body.origin()) + { + warn!( + "Server {} tried knocking room ID {} who has a server name that is globally forbidden. Rejecting.", + body.origin(), + &body.room_id, + ); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); + } + + if let Some(server) = body.room_id.server_name() { + if services + .globals + .config + .forbidden_remote_server_names + .contains(&server.to_owned()) + { + warn!( + "Server {} tried knocking room ID {} which has a server name that is globally forbidden. 
Rejecting.", + body.origin(), + &body.room_id, + ); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); + } + } + + if !services.rooms.metadata.exists(&body.room_id).await { + return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); + } + + // ACL check origin server + services + .rooms + .event_handler + .acl_check(body.origin(), &body.room_id) + .await?; + + let room_version_id = services.rooms.state.get_room_version(&body.room_id).await?; + + if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6) { + return Err!(Request(Forbidden("Room version does not support knocking."))); + } + + let Ok((event_id, value)) = gen_event_id_canonical_json(&body.pdu, &room_version_id) else { + // Event could not be converted to canonical json + return Err!(Request(InvalidParam("Could not convert event to canonical json."))); + }; + + let event_type: StateEventType = serde_json::from_value( + value + .get("type") + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing type property."))? + .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event has invalid event type."))?; + + if event_type != StateEventType::RoomMember { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Not allowed to send non-membership state event to knock endpoint.", + )); + } + + let content: RoomMemberEventContent = serde_json::from_value( + value + .get("content") + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing content property"))? + .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event content is empty or invalid"))?; + + if content.membership != MembershipState::Knock { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Not allowed to send a non-knock membership event to knock endpoint.", + )); + } + + // ACL check sender server name + let sender: OwnedUserId = serde_json::from_value( + value + .get("sender") + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing sender property."))? + .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "sender is not a valid user ID."))?; + + services + .rooms + .event_handler + .acl_check(sender.server_name(), &body.room_id) + .await?; + + // check if origin server is trying to send for another server + if sender.server_name() != body.origin() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Not allowed to knock on behalf of another server.", + )); + } + + let state_key: OwnedUserId = serde_json::from_value( + value + .get("state_key") + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing state_key property."))? 
+ .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "state_key is invalid or not a user ID."))?; + + if state_key != sender { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "State key does not match sender user", + )); + }; + + let origin: OwnedServerName = serde_json::from_value( + serde_json::to_value( + value + .get("origin") + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing origin property."))?, + ) + .expect("CanonicalJson is valid json value"), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; + + let mut event: JsonObject = serde_json::from_str(body.pdu.get()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid knock event PDU."))?; + + event.insert("event_id".to_owned(), "$placeholder".into()); + + let pdu: PduEvent = serde_json::from_value(event.into()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid knock event PDU."))?; + + let mutex_lock = services + .rooms + .event_handler + .mutex_federation + .lock(&body.room_id) + .await; + + let pdu_id = services + .rooms + .event_handler + .handle_incoming_pdu(&origin, &body.room_id, &event_id, value.clone(), true) + .await? + .ok_or_else(|| err!(Request(InvalidParam("Could not accept as timeline event."))))?; + + drop(mutex_lock); + + let knock_room_state = services.rooms.state.summary_stripped(&pdu).await; + + services + .sending + .send_pdu_room(&body.room_id, &pdu_id) + .await?; + + Ok(send_knock::v1::Response { + knock_room_state, + }) +} diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index e77c5d78a99ab3c06569c4256359e54d5f302fe4..e4f41833ceed8b3a8849d92f642e3ee32e24f180 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -1,19 +1,16 @@ #![allow(deprecated)] -use std::collections::BTreeMap; - use axum::extract::State; -use conduit::{Error, Result}; +use conduit::{err, utils::ReadyExt, Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_leave_event}, events::{ room::member::{MembershipState, RoomMemberEventContent}, StateEventType, }, - OwnedServerName, OwnedUserId, RoomId, ServerName, + OwnedUserId, RoomId, ServerName, }; use serde_json::value::RawValue as RawJsonValue; -use tokio::sync::RwLock; use crate::{ service::{pdu::gen_event_id_canonical_json, Services}, @@ -26,9 +23,7 @@ pub(crate) async fn create_leave_event_v1_route( State(services): State<crate::State>, body: Ruma<create_leave_event::v1::Request>, ) -> Result<create_leave_event::v1::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - - create_leave_event(&services, origin, &body.room_id, &body.pdu).await?; + create_leave_event(&services, body.origin(), &body.room_id, &body.pdu).await?; Ok(create_leave_event::v1::Response::new()) } @@ -39,9 +34,7 @@ pub(crate) async fn create_leave_event_v1_route( pub(crate) async fn create_leave_event_v2_route( State(services): State<crate::State>, body: Ruma<create_leave_event::v2::Request>, ) -> Result<create_leave_event::v2::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - - create_leave_event(&services, origin, &body.room_id, &body.pdu).await?; + create_leave_event(&services, body.origin(), &body.room_id, &body.pdu).await?; Ok(create_leave_event::v2::Response::new()) } @@ -49,18 +42,20 @@ pub(crate) async fn create_leave_event_v2_route( async fn create_leave_event( services: &Services, origin: &ServerName, room_id: &RoomId, pdu: 
&RawJsonValue, ) -> Result<()> { - if !services.rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id).await { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } // ACL check origin - services.rooms.event_handler.acl_check(origin, room_id)?; - - let pub_key_map = RwLock::new(BTreeMap::new()); + services + .rooms + .event_handler + .acl_check(origin, room_id) + .await?; // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services.rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id).await?; let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json return Err(Error::BadRequest( @@ -114,7 +109,8 @@ async fn create_leave_event( services .rooms .event_handler - .acl_check(sender.server_name(), room_id)?; + .acl_check(sender.server_name(), room_id) + .await?; if sender.server_name() != origin { return Err(Error::BadRequest( @@ -139,33 +135,19 @@ async fn create_leave_event( )); } - let origin: OwnedServerName = serde_json::from_value( - serde_json::to_value( - value - .get("origin") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event missing origin property."))?, - ) - .expect("CanonicalJson is valid json value"), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; - - services - .server_keys - .fetch_required_signing_keys([&value], &pub_key_map) - .await?; - let mutex_lock = services .rooms .event_handler .mutex_federation .lock(room_id) .await; - let pdu_id: Vec<u8> = services + + let pdu_id = services .rooms .event_handler - .handle_incoming_pdu(&origin, room_id, &event_id, value, true, &pub_key_map) + .handle_incoming_pdu(origin, room_id, &event_id, value, true) .await? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; + .ok_or_else(|| err!(Request(InvalidParam("Could not accept as timeline event."))))?; drop(mutex_lock); @@ -173,10 +155,7 @@ async fn create_leave_event( .rooms .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server| !services.globals.server_is_ours(server)); - - services.sending.send_pdu_servers(servers, &pdu_id)?; + .ready_filter(|server| !services.globals.server_is_ours(server)); - Ok(()) + services.sending.send_pdu_servers(servers, &pdu_id).await } diff --git a/src/api/server/state.rs b/src/api/server/state.rs index d215236afddf1b2a201c0cb1945dd10e1b646689..06a44a999ca398bbea37b8112fc71cd2263fe723 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,9 +1,11 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; -use conduit::{Error, Result}; -use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; +use conduit::{err, result::LogErr, utils::IterStream, Result}; +use futures::{FutureExt, StreamExt, TryStreamExt}; +use ruma::api::federation::event::get_room_state; +use super::AccessCheck; use crate::Ruma; /// # `GET /_matrix/federation/v1/state/{roomId}` @@ -12,61 +14,59 @@ pub(crate) async fn get_room_state_route( State(services): State<crate::State>, body: Ruma<get_room_state::v1::Request>, ) -> Result<get_room_state::v1::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - - services - .rooms - .event_handler - .acl_check(origin, &body.room_id)?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let shortstatehash = services .rooms .state_accessor - .pdu_shortstatehash(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; + .pdu_shortstatehash(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("PDU state not found."))))?; let pdus = services .rooms .state_accessor .state_full_ids(shortstatehash) - .await? - .into_values() - .map(|id| { + .await + .log_err() + .map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))? + .values() + .try_stream() + .and_then(|id| services.rooms.timeline.get_pdu_json(id)) + .and_then(|pdu| { services .sending - .convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap()) + .convert_to_outgoing_federation_event(pdu) + .map(Ok) }) - .collect(); + .try_collect() + .await?; - let auth_chain_ids = services + let auth_chain = services .rooms .auth_chain - .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) + .event_ids_iter(&body.room_id, &[body.event_id.borrow()]) + .await? + .map(Ok) + .and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await }) + .and_then(|pdu| { + services + .sending + .convert_to_outgoing_federation_event(pdu) + .map(Ok) + }) + .try_collect() .await?; Ok(get_room_state::v1::Response { - auth_chain: auth_chain_ids - .filter_map(|id| { - services - .rooms - .timeline - .get_pdu_json(&id) - .ok()? 
- .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) - }) - .collect(), + auth_chain, pdus, }) } diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index d22f2df4afb4ba26a207d7cd996bc8e4f9bab8dc..52d8e7cca8da5bb252eae732d086ebda9306ff78 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,9 +1,12 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; -use ruma::api::{client::error::ErrorKind, federation::event::get_room_state_ids}; +use conduit::{err, Result}; +use futures::StreamExt; +use ruma::api::federation::event::get_room_state_ids; -use crate::{Error, Result, Ruma}; +use super::AccessCheck; +use crate::Ruma; /// # `GET /_matrix/federation/v1/state_ids/{roomId}` /// @@ -12,36 +15,28 @@ pub(crate) async fn get_room_state_ids_route( State(services): State<crate::State>, body: Ruma<get_room_state_ids::v1::Request>, ) -> Result<get_room_state_ids::v1::Response> { - let origin = body.origin.as_ref().expect("server is authenticated"); - - services - .rooms - .event_handler - .acl_check(origin, &body.room_id)?; - - if !services - .rooms - .state_accessor - .is_world_readable(&body.room_id)? - && !services - .rooms - .state_cache - .server_in_room(origin, &body.room_id)? - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); + AccessCheck { + services: &services, + origin: body.origin(), + room_id: &body.room_id, + event_id: None, } + .check() + .await?; let shortstatehash = services .rooms .state_accessor - .pdu_shortstatehash(&body.event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; + .pdu_shortstatehash(&body.event_id) + .await + .map_err(|_| err!(Request(NotFound("Pdu state not found."))))?; let pdu_ids = services .rooms .state_accessor .state_full_ids(shortstatehash) - .await? + .await + .map_err(|_| err!(Request(NotFound("State ids not found"))))? .into_values() .map(|id| (*id).to_owned()) .collect(); @@ -49,11 +44,14 @@ pub(crate) async fn get_room_state_ids_route( let auth_chain_ids = services .rooms .auth_chain - .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) - .await?; + .event_ids_iter(&body.room_id, &[body.event_id.borrow()]) + .await? + .map(|id| (*id).to_owned()) + .collect() + .await; Ok(get_room_state_ids::v1::Response { - auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), + auth_chain_ids, pdu_ids, }) } diff --git a/src/api/server/user.rs b/src/api/server/user.rs index e9a400a79eddbb3c53e23b6fcc1ee1b5a1c0f087..40f330a121810ead55a14069ba0c49c005201363 100644 --- a/src/api/server/user.rs +++ b/src/api/server/user.rs @@ -1,5 +1,6 @@ use axum::extract::State; use conduit::{Error, Result}; +use futures::{FutureExt, StreamExt, TryFutureExt}; use ruma::api::{ client::error::ErrorKind, federation::{ @@ -26,43 +27,51 @@ pub(crate) async fn get_devices_route( )); } - let origin = body.origin.as_ref().expect("server is authenticated"); - + let user_id = &body.user_id; Ok(get_devices::v1::Response { - user_id: body.user_id.clone(), + user_id: user_id.clone(), stream_id: services .users - .get_devicelist_version(&body.user_id)? 
+ .get_devicelist_version(user_id) + .await .unwrap_or(0) - .try_into() - .expect("version will not grow that large"), + .try_into()?, devices: services .users - .all_devices_metadata(&body.user_id) - .filter_map(Result::ok) - .filter_map(|metadata| { - let device_id_string = metadata.device_id.as_str().to_owned(); + .all_devices_metadata(user_id) + .filter_map(|metadata| async move { + let device_id = metadata.device_id.clone(); + let device_id_clone = device_id.clone(); + let device_id_string = device_id.as_str().to_owned(); let device_display_name = if services.globals.allow_device_name_federation() { - metadata.display_name + metadata.display_name.clone() } else { Some(device_id_string) }; - Some(UserDevice { - keys: services - .users - .get_device_keys(&body.user_id, &metadata.device_id) - .ok()??, - device_id: metadata.device_id, - device_display_name, - }) + + services + .users + .get_device_keys(user_id, &device_id_clone) + .map_ok(|keys| UserDevice { + device_id, + keys, + device_display_name, + }) + .map(Result::ok) + .await }) - .collect(), + .collect() + .await, master_key: services .users - .get_master_key(None, &body.user_id, &|u| u.server_name() == origin)?, + .get_master_key(None, &body.user_id, &|u| u.server_name() == body.origin()) + .await + .ok(), self_signing_key: services .users - .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)?, + .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == body.origin()) + .await + .ok(), }) } diff --git a/src/api/server/utils.rs b/src/api/server/utils.rs new file mode 100644 index 0000000000000000000000000000000000000000..278465caae60e7ae90ff7923a4c92ccc49dc31cf --- /dev/null +++ b/src/api/server/utils.rs @@ -0,0 +1,60 @@ +use conduit::{implement, is_false, Err, Result}; +use conduit_service::Services; +use futures::{future::OptionFuture, join, FutureExt}; +use ruma::{EventId, RoomId, ServerName}; + +pub(super) struct AccessCheck<'a> { + pub(super) services: &'a Services, + pub(super) origin: &'a ServerName, + pub(super) room_id: &'a RoomId, + pub(super) event_id: Option<&'a EventId>, +} + +#[implement(AccessCheck, params = "<'_>")] +pub(super) async fn check(&self) -> Result { + let acl_check = self + .services + .rooms + .event_handler + .acl_check(self.origin, self.room_id) + .map(|result| result.is_ok()); + + let world_readable = self + .services + .rooms + .state_accessor + .is_world_readable(self.room_id); + + let server_in_room = self + .services + .rooms + .state_cache + .server_in_room(self.origin, self.room_id); + + let server_can_see: OptionFuture<_> = self + .event_id + .map(|event_id| { + self.services + .rooms + .state_accessor + .server_can_see_event(self.origin, self.room_id, event_id) + }) + .into(); + + let (world_readable, server_in_room, server_can_see, acl_check) = + join!(world_readable, server_in_room, server_can_see, acl_check); + + if !acl_check { + return Err!(Request(Forbidden("Server access denied."))); + } + + if !world_readable && !server_in_room { + return Err!(Request(Forbidden("Server is not in room."))); + } + + if server_can_see.is_some_and(is_false!()) { + return Err!(Request(Forbidden("Server is not allowed to see event."))); + } + + Ok(()) +} diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 713647342460d7fd1d8f3ff056598e70b79304ca..b93f9a7775638ee8295718c118fce14219d32f18 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -57,6 +57,7 @@ argon2.workspace = true arrayvec.workspace = true axum.workspace = true bytes.workspace = true 
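As an aside, the AccessCheck helper introduced in utils.rs above awaits four independent permission checks concurrently via futures' join! rather than sequentially. A minimal self-contained sketch of that pattern, using only the futures crate; the three async fns are hypothetical stand-ins, not conduwuit's real services:

// Sketch: run independent async permission checks concurrently and
// combine the results, mirroring the join! usage in AccessCheck::check.
use futures::{executor::block_on, join};

async fn acl_check(origin: &str) -> bool { !origin.is_empty() }
async fn is_world_readable(room: &str) -> bool { room.starts_with('#') }
async fn server_in_room(origin: &str, room: &str) -> bool {
    !origin.is_empty() && !room.is_empty()
}

fn main() {
    block_on(async {
        // All three futures are polled concurrently; none blocks the others.
        let (acl, world_readable, in_room) = join!(
            acl_check("example.org"),
            is_world_readable("#room"),
            server_in_room("example.org", "#room"),
        );

        if !acl {
            println!("Server access denied.");
        } else if !world_readable && !in_room {
            println!("Server is not in room.");
        } else {
            println!("access granted");
        }
    });
}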
+bytesize.workspace = true cargo_toml.workspace = true checked_ops.workspace = true chrono.workspace = true @@ -67,6 +68,7 @@ ctor.workspace = true cyborgtime.workspace = true either.workspace = true figment.workspace = true +futures.workspace = true http-body-util.workspace = true http.workspace = true image.workspace = true @@ -82,6 +84,7 @@ ruma.workspace = true sanitize-filename.workspace = true serde_json.workspace = true serde_regex.workspace = true +serde_yaml.workspace = true serde.workspace = true thiserror.workspace = true tikv-jemallocator.optional = true diff --git a/src/core/config/check.rs b/src/core/config/check.rs index 8dea55d837135e9fc9aeef767c613677a9802f60..c0d05533701f78bbde98ecca2dafad20f9cb8510 100644 --- a/src/core/config/check.rs +++ b/src/core/config/check.rs @@ -94,6 +94,22 @@ pub fn check(config: &Config) -> Result<()> { )); } + // check if we can read the token file path, and check if the file is empty + if config.registration_token_file.as_ref().is_some_and(|path| { + let Ok(token) = std::fs::read_to_string(path).inspect_err(|e| { + error!("Failed to read the registration token file: {e}"); + }) else { + return true; + }; + + token.is_empty() + }) { + return Err!(Config( + "registration_token_file", + "Registration token file was specified but is empty or failed to be read" + )); + } + if config.max_request_size < 5_120_000 { return Err!(Config( "max_request_size", @@ -111,12 +127,13 @@ pub fn check(config: &Config) -> Result<()> { if config.allow_registration && !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse && config.registration_token.is_none() + && config.registration_token_file.is_none() { return Err!(Config( "registration_token", "!! You have `allow_registration` enabled without a token configured in your config which means you are \ allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n -If this is not the intended behaviour, please set a registration token with the `registration_token` config option.\n +If this is not the intended behaviour, please set a registration token.\n For security and safety reasons, conduwuit will shut down. 
If you are extra sure this is the desired behaviour you \ want, please set the following config option to true: `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`" @@ -126,6 +143,7 @@ pub fn check(config: &Config) -> Result<()> { if config.allow_registration && config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse && config.registration_token.is_none() + && config.registration_token_file.is_none() { warn!( "Open registration is enabled via setting \ diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index d2d583a8c49d0d6cfbe73a01f20491a4a5b8891d..cb9d087bbd46ae5dae3ee4c1d1bb34ff014da7b3 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -1,10 +1,14 @@ +pub mod check; +pub mod proxy; + use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashSet}, fmt, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, path::PathBuf, }; +use conduit_macros::config_example_generator; use either::{ Either, Either::{Left, Right}, @@ -14,363 +18,1466 @@ use itertools::Itertools; use regex::RegexSet; use ruma::{ - api::client::discovery::discover_support::ContactRole, OwnedRoomId, OwnedServerName, OwnedUserId, RoomVersionId, + api::client::discovery::discover_support::ContactRole, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, + RoomVersionId, }; use serde::{de::IgnoredAny, Deserialize}; use url::Url; pub use self::check::check; use self::proxy::ProxyConfig; -use crate::{error::Error, utils::sys, Err, Result}; - -pub mod check; -pub mod proxy; +use crate::{err, error::Error, utils::sys, Result}; /// all the config options for conduwuit -#[derive(Clone, Debug, Deserialize)] #[allow(clippy::struct_excessive_bools)] +#[allow(rustdoc::broken_intra_doc_links, rustdoc::bare_urls)] +#[derive(Clone, Debug, Deserialize)] +#[config_example_generator( + filename = "conduwuit-example.toml", + section = "global", + undocumented = "# This item is undocumented. Please contribute documentation for it.", + header = "### conduwuit Configuration\n###\n### THIS FILE IS GENERATED. CHANGES/CONTRIBUTIONS IN THE REPO WILL\n### BE \ + OVERWRITTEN!\n###\n### You should rename this file before configuring your server. Changes\n### to \ + documentation and defaults can be contributed in source code at\n### src/core/config/mod.rs. This file \ + is generated when building.\n###\n### Any values pre-populated are the default values for said config \ + option.\n###\n### At the minimum, you MUST edit all the config options to your environment\n### that say \ + \"YOU NEED TO EDIT THIS\".\n### See https://conduwuit.puppyirl.gay/configuration.html for ways to\n### configure conduwuit\n", + ignore = "catchall well_known tls" +)] pub struct Config { - /// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6) + /// The server_name is the pretty name of this server. It is used as a + /// suffix for user and room IDs/aliases. + /// + /// See the docs for reverse proxying and delegation: https://conduwuit.puppyirl.gay/deploying/generic.html#setting-up-the-reverse-proxy + /// Also see the `[global.well_known]` config section at the very bottom. + /// + /// Examples of delegation: + /// - https://puppygock.gay/.well-known/matrix/server + /// - https://puppygock.gay/.well-known/matrix/client + /// + /// YOU NEED TO EDIT THIS. THIS CANNOT BE CHANGED AFTER WITHOUT A DATABASE + /// WIPE. + /// + /// example: "conduwuit.woof" + pub server_name: OwnedServerName, + + /// default address (IPv4 or IPv6) conduwuit will listen on. 
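Stepping back to the registration_token_file validation added in check.rs above: it boils down to "readable and non-empty". A hedged stand-alone sketch of that logic using only std; the error type is simplified to a String, and trimming whitespace here is a small liberty over the exact check:

// Sketch: validate a registration token file the way the check above does;
// an unreadable file and an empty file both fail.
use std::{fs, path::Path};

fn check_registration_token_file(path: &Path) -> Result<String, String> {
    let token = fs::read_to_string(path)
        .map_err(|e| format!("Failed to read the registration token file: {e}"))?;

    if token.trim().is_empty() {
        return Err("Registration token file was specified but is empty".to_owned());
    }

    Ok(token.trim().to_owned())
}

fn main() {
    match check_registration_token_file(Path::new("/etc/conduwuit/.reg_token")) {
        Ok(_) => println!("token file OK"),
        Err(e) => eprintln!("{e}"),
    }
}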
+ /// + /// If you are using Docker or a container NAT networking setup, this must + /// be "0.0.0.0". + /// + /// To listen on multiple addresses, specify a vector e.g. ["127.0.0.1", + /// "::1"] + /// + /// default: ["127.0.0.1", "::1"] #[serde(default = "default_address")] address: ListeningAddr, - /// default TCP port(s) conduwuit will listen on + + /// The port(s) conduwuit will be running on. + /// + /// See https://conduwuit.puppyirl.gay/deploying/generic.html#setting-up-the-reverse-proxy for reverse proxying. + /// + /// Docker users: Don't change this, you'll need to map an external port to + /// this. + /// + /// To listen on multiple ports, specify a vector e.g. [8080, 8448] + /// + /// default: 8008 #[serde(default = "default_port")] port: ListeningPort, + + // external structure; separate section pub tls: Option<TlsConfig>, + + /// Uncomment unix_socket_path to listen on a UNIX socket at the specified + /// path. If listening on a UNIX socket, you MUST remove/comment the + /// 'address' key if defined, AND add your reverse proxy to the 'conduwuit' + /// group, unless world RW permissions are specified with unix_socket_perms + /// (666 minimum). + /// + /// example: "/run/conduwuit/conduwuit.sock" pub unix_socket_path: Option<PathBuf>, + + /// The default permissions (in octal) to create the UNIX socket with. + /// + /// default: 660 #[serde(default = "default_unix_socket_perms")] pub unix_socket_perms: u32, - pub server_name: OwnedServerName, - #[serde(default = "default_database_backend")] - pub database_backend: String, + + /// This is the only directory where conduwuit will save its data, including + /// media. + /// Note: this was previously "/var/lib/matrix-conduit" + /// + /// YOU NEED TO EDIT THIS. + /// + /// example: "/var/lib/conduwuit" pub database_path: PathBuf, + + /// conduwuit supports online database backups using RocksDB's Backup engine + /// API. To use this, set a database backup path that conduwuit can write + /// to. + /// + /// See https://conduwuit.puppyirl.gay/maintenance.html#backups for more information. + /// + /// example: "/opt/conduwuit-db-backups" pub database_backup_path: Option<PathBuf>, + + /// The amount of online RocksDB database backups to keep/retain, if using + /// "database_backup_path", before deleting the oldest one. + /// + /// default: 1 #[serde(default = "default_database_backups_to_keep")] pub database_backups_to_keep: i16, + + /// Set this to any float value in megabytes for conduwuit to tell the + /// database engine that this much memory is available for database-related + /// caches. + /// + /// May be useful if you have significant memory to spare to increase + /// performance. + /// + /// Similar to the individual LRU caches, this is scaled up with your CPU + /// core count. + /// + /// This defaults to 128.0 + (64.0 * CPU core count) #[serde(default = "default_db_cache_capacity_mb")] pub db_cache_capacity_mb: f64, + + /// Option to control arbitrary text appended to the end of the user's + /// displayname upon registration, with a space before the text. This was the + /// lightning bolt emoji option, just replaced with support for adding your + /// own custom text or emojis. To disable, set this to "" (an empty string). + /// + /// The default is the trans pride flag.
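The db_cache_capacity_mb default quoted above, 128.0 + (64.0 * CPU core count), is easy to reproduce. A sketch using std's available_parallelism; how conduwuit actually detects core count may differ:

// Sketch: compute the documented db_cache_capacity_mb default,
// 128.0 + (64.0 * CPU core count), from the visible core count.
use std::thread::available_parallelism;

fn default_db_cache_capacity_mb() -> f64 {
    let cores = available_parallelism().map(usize::from).unwrap_or(1);
    128.0 + (64.0 * cores as f64)
}

fn main() {
    // e.g. prints 384 on a 4-core machine
    println!("{}", default_db_cache_capacity_mb());
}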
+ /// + /// example: "ðŸ³ï¸âš§ï¸" + /// + /// default: "ðŸ³ï¸âš§ï¸" #[serde(default = "default_new_user_displayname_suffix")] pub new_user_displayname_suffix: String, + + /// If enabled, conduwuit will send a simple GET request periodically to + /// `https://pupbrain.dev/check-for-updates/stable` for any new + /// announcements made. Despite the name, this is not an update check + /// endpoint, it is simply an announcement check endpoint. + /// + /// This is disabled by default as this is rarely used except for security + /// updates or major updates. #[serde(default)] pub allow_check_for_updates: bool, - #[serde(default = "default_pdu_cache_capacity")] - pub pdu_cache_capacity: u32, + /// Set this to any float value to multiply conduwuit's in-memory LRU caches + /// with such as "auth_chain_cache_capacity". + /// + /// May be useful if you have significant memory to spare to increase + /// performance. This was previously called + /// `conduit_cache_capacity_modifier`. + /// + /// If you have low memory, reducing this may be viable. + /// + /// By default, the individual caches such as "auth_chain_cache_capacity" + /// are scaled by your CPU core count. + /// + /// default: 1.0 #[serde(default = "default_cache_capacity_modifier", alias = "conduit_cache_capacity_modifier")] pub cache_capacity_modifier: f64, + + /// default: varies by system + #[serde(default = "default_pdu_cache_capacity")] + pub pdu_cache_capacity: u32, + + /// default: varies by system #[serde(default = "default_auth_chain_cache_capacity")] pub auth_chain_cache_capacity: u32, + + /// default: varies by system #[serde(default = "default_shorteventid_cache_capacity")] pub shorteventid_cache_capacity: u32, + + /// default: varies by system #[serde(default = "default_eventidshort_cache_capacity")] pub eventidshort_cache_capacity: u32, + + /// default: varies by system #[serde(default = "default_shortstatekey_cache_capacity")] pub shortstatekey_cache_capacity: u32, + + /// default: varies by system #[serde(default = "default_statekeyshort_cache_capacity")] pub statekeyshort_cache_capacity: u32, + + /// default: varies by system #[serde(default = "default_server_visibility_cache_capacity")] pub server_visibility_cache_capacity: u32, + + /// default: varies by system #[serde(default = "default_user_visibility_cache_capacity")] pub user_visibility_cache_capacity: u32, + + /// default: varies by system #[serde(default = "default_stateinfo_cache_capacity")] pub stateinfo_cache_capacity: u32, + + /// default: varies by system #[serde(default = "default_roomid_spacehierarchy_cache_capacity")] pub roomid_spacehierarchy_cache_capacity: u32, + /// Maximum entries stored in DNS memory-cache. The size of an entry may + /// vary so please take care if raising this value excessively. Only + /// decrease this when using an external DNS cache. Please note + /// that systemd-resolved does *not* count as an external cache, even when + /// configured to do so. + /// + /// default: 32768 #[serde(default = "default_dns_cache_entries")] pub dns_cache_entries: u32, + + /// Minimum time-to-live in seconds for entries in the DNS cache. The + /// default may appear high to most administrators; this is by design as the + /// majority of NXDOMAINs are correct for a long time (e.g. the server is no + /// longer running Matrix). Only decrease this if you are using an external + /// DNS cache. 
+ /// + /// default: 259200 #[serde(default = "default_dns_min_ttl")] pub dns_min_ttl: u64, + + /// Minimum time-to-live in seconds for NXDOMAIN entries in the DNS cache. + /// This value is critical for the server to federate efficiently. + /// NXDOMAINs are assumed to not be returning to the federation + /// and are aggressively cached rather than constantly rechecked. + /// + /// Defaults to 3 days as these are *very rarely* false negatives. + /// + /// default: 259200 #[serde(default = "default_dns_min_ttl_nxdomain")] pub dns_min_ttl_nxdomain: u64, + + /// Number of retries after a timeout. + /// + /// default: 10 #[serde(default = "default_dns_attempts")] pub dns_attempts: u16, + + /// The number of seconds to wait for a reply to a DNS query. Please note + /// that recursive queries can take up to several seconds for some domains, + /// so this value should not be too low, especially on slower hardware or + /// resolvers. + /// + /// default: 10 #[serde(default = "default_dns_timeout")] pub dns_timeout: u64, + + /// Fallback to TCP on DNS errors. Set this to false if unsupported by + /// nameserver. #[serde(default = "true_fn")] pub dns_tcp_fallback: bool, + + /// Enable to query all nameservers until the domain is found. Referred to + /// as "trust_negative_responses" in hickory_resolver. This can avoid + /// useless DNS queries if the first nameserver responds with NXDOMAIN or + /// an empty NOERROR response. #[serde(default = "true_fn")] pub query_all_nameservers: bool, + + /// Enables using *only* TCP for querying your specified nameservers instead + /// of UDP. + /// + /// If you are running conduwuit in a container environment, this config option may need to be enabled. See https://conduwuit.puppyirl.gay/troubleshooting.html#potential-dns-issues-when-using-docker for more details. #[serde(default)] pub query_over_tcp_only: bool, + + /// DNS A/AAAA record lookup strategy + /// + /// Takes the number of one of the following options: + /// 1 - Ipv4Only (Only query for A records, no AAAA/IPv6) + /// + /// 2 - Ipv6Only (Only query for AAAA records, no A/IPv4) + /// + /// 3 - Ipv4AndIpv6 (Query for A and AAAA records in parallel, uses whatever + /// returns a successful response first) + /// + /// 4 - Ipv6thenIpv4 (Query for AAAA record, if that fails then query the A + /// record) + /// + /// 5 - Ipv4thenIpv6 (Query for A record, if that fails then query the AAAA + /// record) + /// + /// If you don't have IPv6 networking, then for better DNS performance it + /// may be suitable to set this to Ipv4Only (1) as you will never use + /// the AAAA record contents even when the AAAA lookup succeeds instead + /// of the A record. + /// + /// default: 5 #[serde(default = "default_ip_lookup_strategy")] pub ip_lookup_strategy: u8, + /// Max request size for file uploads in bytes. Defaults to 20MB. + /// + /// default: 20971520 #[serde(default = "default_max_request_size")] pub max_request_size: usize, + + /// default: 192 #[serde(default = "default_max_fetch_prev_events")] pub max_fetch_prev_events: u16, + /// Default/base connection timeout (seconds). This is used only by URL + /// previews and update/news endpoint checks. + /// + /// default: 10 #[serde(default = "default_request_conn_timeout")] pub request_conn_timeout: u64, + + /// Default/base request timeout (seconds). The time waiting to receive more + /// data from another server. This is used only by URL previews, + /// update/news, and misc endpoint checks.
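The five numeric ip_lookup_strategy values above correspond one-to-one to the lookup strategies of a resolver such as hickory_resolver. A hedged sketch of decoding the config value with a local enum; the actual conduwuit mapping code is not shown in this diff:

// Sketch: decode the numeric ip_lookup_strategy config value.
// The enum mirrors the five documented options; 5 (Ipv4thenIpv6) is the default.
#[derive(Debug)]
enum LookupStrategy {
    Ipv4Only,
    Ipv6Only,
    Ipv4AndIpv6,
    Ipv6thenIpv4,
    Ipv4thenIpv6,
}

fn parse_ip_lookup_strategy(value: u8) -> LookupStrategy {
    match value {
        1 => LookupStrategy::Ipv4Only,
        2 => LookupStrategy::Ipv6Only,
        3 => LookupStrategy::Ipv4AndIpv6,
        4 => LookupStrategy::Ipv6thenIpv4,
        // any other value falls back to the documented default of 5
        _ => LookupStrategy::Ipv4thenIpv6,
    }
}

fn main() {
    assert!(matches!(parse_ip_lookup_strategy(5), LookupStrategy::Ipv4thenIpv6));
    println!("{:?}", parse_ip_lookup_strategy(1));
}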
+ /// + /// default: 35 #[serde(default = "default_request_timeout")] pub request_timeout: u64, + + /// Default/base request total timeout (seconds). The time limit for a whole + /// request. This is set very high to not cancel healthy requests while + /// serving as a backstop. This is used only by URL previews and + /// update/news endpoint checks. + /// + /// default: 320 #[serde(default = "default_request_total_timeout")] pub request_total_timeout: u64, + + /// Default/base idle connection pool timeout (seconds). This is used only + /// by URL previews and update/news endpoint checks. + /// + /// default: 5 #[serde(default = "default_request_idle_timeout")] pub request_idle_timeout: u64, + + /// Default/base max idle connections per host. This is used only by URL + /// previews and update/news endpoint checks. Defaults to 1 as generally the + /// same open connection can be re-used. + /// + /// default: 1 #[serde(default = "default_request_idle_per_host")] pub request_idle_per_host: u16, + + /// Federation well-known resolution connection timeout (seconds) + /// + /// default: 6 #[serde(default = "default_well_known_conn_timeout")] pub well_known_conn_timeout: u64, + + /// Federation HTTP well-known resolution request timeout (seconds) + /// + /// default: 10 #[serde(default = "default_well_known_timeout")] pub well_known_timeout: u64, + + /// Federation client request timeout (seconds). You most definitely want + /// this to be high to account for extremely large room joins, slow + /// homeservers, your own resources etc. + /// + /// default: 300 #[serde(default = "default_federation_timeout")] pub federation_timeout: u64, + + /// Federation client idle connection pool timeout (seconds) + /// + /// default: 25 #[serde(default = "default_federation_idle_timeout")] pub federation_idle_timeout: u64, + + /// Federation client max idle connections per host. Defaults to 1 as + /// generally the same open connection can be re-used + /// + /// default: 1 #[serde(default = "default_federation_idle_per_host")] pub federation_idle_per_host: u16, + + /// Federation sender request timeout (seconds). The time it takes for the + /// remote server to process sent transactions can take a while. + /// + /// default: 180 #[serde(default = "default_sender_timeout")] pub sender_timeout: u64, + + /// Federation sender idle connection pool timeout (seconds) + /// + /// default: 180 #[serde(default = "default_sender_idle_timeout")] pub sender_idle_timeout: u64, + + /// Federation sender transaction retry backoff limit (seconds) + /// + /// default: 86400 #[serde(default = "default_sender_retry_backoff_limit")] pub sender_retry_backoff_limit: u64, + + /// Appservice URL request connection timeout. Defaults to 35 seconds as + /// generally appservices are hosted within the same network. + /// + /// default: 35 #[serde(default = "default_appservice_timeout")] pub appservice_timeout: u64, + + /// Appservice URL idle connection pool timeout (seconds) + /// + /// default: 300 #[serde(default = "default_appservice_idle_timeout")] pub appservice_idle_timeout: u64, + + /// Notification gateway pusher idle connection pool timeout + /// + /// default: 15 #[serde(default = "default_pusher_idle_timeout")] pub pusher_idle_timeout: u64, + /// Enables registration. If set to false, no users can register on this + /// server. 
+ /// + /// If set to true without a token configured, users can register without any + /// form of 2nd-step verification, and only if you set + /// `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` to + /// true in your config. + /// + /// If you would like registration only via registration tokens, please configure + /// `registration_token` or `registration_token_file`. #[serde(default)] pub allow_registration: bool, + #[serde(default)] pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool, + + /// A static registration token that new users will have to provide when + /// creating an account. If unset and `allow_registration` is true, + /// registration is open without any condition. + /// + /// YOU NEED TO EDIT THIS OR USE registration_token_file. + /// + /// example: "o&^uCtes4HPf0Vu@F20jQeeWE7" pub registration_token: Option<String>, + + /// Path to a file on the system that gets read for the registration token. + /// This config option takes precedence over "registration_token". + /// + /// conduwuit must be able to access the file, and it must not be empty. + /// + /// example: "/etc/conduwuit/.reg_token" + pub registration_token_file: Option<PathBuf>, + + /// Controls whether encrypted rooms and events are allowed. #[serde(default = "true_fn")] pub allow_encryption: bool, + + /// Controls whether federation is allowed or not. It is not recommended to + /// disable this after the fact due to potential federation breakage. #[serde(default = "true_fn")] pub allow_federation: bool, + #[serde(default)] pub federation_loopback: bool, + + /// Set this to true to require authentication on the normally + /// unauthenticated profile retrieval endpoints (GET) + /// "/_matrix/client/v3/profile/{userId}". + /// + /// This can prevent profile scraping. + #[serde(default)] + pub require_auth_for_profile_requests: bool, + + /// Set this to true to allow your server's public room directory to be + /// federated. Set this to false to protect against /publicRooms spiders, + /// but this will forbid external users from viewing your server's public room + /// directory. If federation is disabled entirely (`allow_federation`), + /// this is inherently false. #[serde(default)] pub allow_public_room_directory_over_federation: bool, + + /// Set this to true to allow your server's public room directory to be + /// queried without client authentication (access token) through the Client + /// APIs. Set this to false to protect against /publicRooms spiders. #[serde(default)] pub allow_public_room_directory_without_auth: bool, + + /// allow guests/unauthenticated users to access TURN credentials + /// + /// this is the equivalent of Synapse's `turn_allow_guests` config option. + /// this allows any unauthenticated user to call the endpoint + /// `/_matrix/client/v3/voip/turnServer`. + /// + /// It is unlikely you need to enable this as all major clients support + /// authentication for this endpoint, which prevents misuse of your TURN server + /// by potential bots. #[serde(default)] pub turn_allow_guests: bool, + + /// Set this to true to lock down your server's public room directory and + /// only allow admins to publish rooms to the room directory. Unpublishing + /// is still allowed by all users with this enabled. #[serde(default)] pub lockdown_public_room_directory: bool, + + /// Set this to true to allow federating device display names / allow + /// external users to see your device display name.
If federation is + /// disabled entirely (`allow_federation`), this is inherently false. For + /// privacy reasons, this is best left disabled. #[serde(default)] pub allow_device_name_federation: bool, - #[serde(default = "true_fn")] - pub allow_profile_lookup_federation_requests: bool, + + /// Config option to allow or disallow incoming federation requests that + /// obtain the profiles of our local users from + /// `/_matrix/federation/v1/query/profile` + /// + /// Increases privacy of your local users' details, such as display names, but some + /// remote users may get a false "this user does not exist" error when they + /// try to invite you to a DM or room. It can also protect against profile + /// spiders. + /// + /// This is inherently false if `allow_federation` is disabled. + #[serde(default = "true_fn", alias = "allow_profile_lookup_federation_requests")] + pub allow_inbound_profile_lookup_federation_requests: bool, + + /// controls whether standard users are allowed to create rooms. appservices + /// and admins are always allowed to create rooms #[serde(default = "true_fn")] pub allow_room_creation: bool, + + /// Set to false to disable users from joining or creating room versions + /// that aren't 100% officially supported by conduwuit. + /// + /// conduwuit officially supports room versions 6 - 11. + /// + /// conduwuit has slightly experimental (though works fine in practice) + /// support for versions 3 - 5 #[serde(default = "true_fn")] pub allow_unstable_room_versions: bool, + + /// default room version conduwuit will create rooms with. + /// + /// per spec, room version 10 is the default. + /// + /// default: 10 #[serde(default = "default_default_room_version")] pub default_room_version: RoomVersionId, + + // external structure; separate section #[serde(default)] pub well_known: WellKnownConfig, + #[serde(default)] pub allow_jaeger: bool, + + /// default: "info" #[serde(default = "default_jaeger_filter")] pub jaeger_filter: String, + + /// If the 'perf_measurements' compile-time feature is enabled, enables + /// collecting folded stack trace profile of tracing spans using + /// tracing_flame. The resulting profile can be visualized with inferno[1], + /// speedscope[2], or a number of other tools. + /// + /// [1]: https://github.com/jonhoo/inferno + /// [2]: www.speedscope.app #[serde(default)] pub tracing_flame: bool, + + /// default: "info" #[serde(default = "default_tracing_flame_filter")] pub tracing_flame_filter: String, + + /// default: "./tracing.folded" #[serde(default = "default_tracing_flame_output_path")] pub tracing_flame_output_path: String, + + /// Examples: + /// - No proxy (default): + /// proxy = "none" + /// + /// - For global proxy, create the section at the bottom of this file: + /// [global.proxy] + /// global = { url = "socks5h://localhost:9050" } + /// + /// - To proxy some domains: + /// [global.proxy] + /// [[global.proxy.by_domain]] + /// url = "socks5h://localhost:9050" + /// include = ["*.onion", "matrix.myspecial.onion"] + /// exclude = ["*.myspecial.onion"] + /// + /// Include vs. Exclude: + /// - If include is an empty list, it is assumed to be `["*"]`. + /// - If a domain matches both the exclude and include list, the proxy will + /// only be used if it was included because of a more specific rule than + /// it was excluded. In the above example, the proxy would be used for + /// `ordinary.onion`, `matrix.myspecial.onion`, but not + /// `hello.myspecial.onion`.
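The include/exclude precedence described above ("the more specific rule wins") can be illustrated with a small sketch. The real matching lives in src/core/config/proxy.rs and may differ; here specificity is approximated by pattern length, and patterns are either exact domains or "*.suffix" globs:

// Sketch: documented proxy include/exclude precedence.
fn pattern_matches(pattern: &str, domain: &str) -> bool {
    match pattern.strip_prefix('*') {
        Some(suffix) => domain.ends_with(suffix), // "*" has an empty suffix, matching all
        None => domain == pattern,
    }
}

fn best_match_len(patterns: &[&str], domain: &str) -> Option<usize> {
    patterns.iter().filter(|p| pattern_matches(p, domain)).map(|p| p.len()).max()
}

fn use_proxy(include: &[&str], exclude: &[&str], domain: &str) -> bool {
    // an empty include list is treated as ["*"]
    let include: &[&str] = if include.is_empty() { &["*"] } else { include };
    match (best_match_len(include, domain), best_match_len(exclude, domain)) {
        (Some(inc), Some(exc)) => inc > exc, // the more specific rule wins
        (Some(_), None) => true,
        _ => false,
    }
}

fn main() {
    let include = ["*.onion", "matrix.myspecial.onion"];
    let exclude = ["*.myspecial.onion"];
    assert!(use_proxy(&include, &exclude, "ordinary.onion"));
    assert!(use_proxy(&include, &exclude, "matrix.myspecial.onion"));
    assert!(!use_proxy(&include, &exclude, "hello.myspecial.onion"));
}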
+ /// + /// default: "none" #[serde(default)] pub proxy: ProxyConfig, + pub jwt_secret: Option<String>, + + /// Servers listed here will be used to gather public keys of other servers + /// (notary trusted key servers). + /// + /// Currently, conduwuit doesn't support inbound batched key requests, so + /// this list should only contain other Synapse servers + /// + /// example: ["matrix.org", "constellatory.net", "tchncs.de"] + /// + /// default: ["matrix.org"] #[serde(default = "default_trusted_servers")] pub trusted_servers: Vec<OwnedServerName>, - #[serde(default = "true_fn")] + + /// Whether to query the servers listed in trusted_servers first or query + /// the origin server first. For best security, querying the origin server + /// first is advised to minimize the exposure to a compromised trusted + /// server. For maximum federation/join performance this can be set to true, + /// however other options exist to query trusted servers first under + /// specific high-load circumstances and should be evaluated before setting + /// this to true. + #[serde(default)] pub query_trusted_key_servers_first: bool, + + /// Whether to query the servers listed in trusted_servers first + /// specifically on room joins. This option limits the exposure to a + /// compromised trusted server to room joins only. The join operation + /// requires gathering keys from many origin servers which can cause + /// significant delays. Therefor this defaults to true to mitigate + /// unexpected delays out-of-the-box. The security-paranoid or those + /// willing to tolerate delays are advised to set this to false. Note that + /// setting query_trusted_key_servers_first to true causes this option to + /// be ignored. + #[serde(default = "true_fn")] + pub query_trusted_key_servers_first_on_join: bool, + + /// Only query trusted servers for keys and never the origin server. This is + /// intended for clusters or custom deployments using their trusted_servers + /// as forwarding-agents to cache and deduplicate requests. Notary servers + /// do not act as forwarding-agents by default, therefor do not enable this + /// unless you know exactly what you are doing. + #[serde(default)] + pub only_query_trusted_key_servers: bool, + + /// Maximum number of keys to request in each trusted server batch query. + /// + /// default: 1024 + #[serde(default = "default_trusted_server_batch_size")] + pub trusted_server_batch_size: usize, + + /// max log level for conduwuit. allows debug, info, warn, or error + /// see also: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives + /// + /// **Caveat**: + /// For release builds, the tracing crate is configured to only implement + /// levels higher than error to avoid unnecessary overhead in the compiled + /// binary from trace macros. For debug builds, this restriction is not + /// applied. + /// + /// default: "info" #[serde(default = "default_log")] pub log: String, + + /// controls whether logs will be outputted with ANSI colours #[serde(default = "true_fn", alias = "log_colours")] pub log_colors: bool, + + /// configures the span events which will be outputted with the log + /// + /// default: "none" + #[serde(default = "default_log_span_events")] + pub log_span_events: String, + + /// OpenID token expiration/TTL in seconds + /// + /// These are the OpenID tokens that are primarily used for Matrix account + /// integrations (e.g. 
Vector Integrations in Element), *not* OIDC/OpenID + /// Connect/etc + /// + /// default: 3600 #[serde(default = "default_openid_token_ttl")] pub openid_token_ttl: u64, + + /// static TURN username to provide the client if not using a shared secret + /// ("turn_secret"). It is recommended to use a shared secret over static + /// credentials. #[serde(default)] pub turn_username: String, + + /// static TURN password to provide the client if not using a shared secret + /// ("turn_secret"). It is recommended to use a shared secret over static + /// credentials. #[serde(default)] pub turn_password: String, - #[serde(default = "Vec::new")] + + /// vector list of TURN URIs/servers to use + /// + /// replace "example.turn.uri" with your TURN domain, such as the coturn + /// "realm" config option. if using TURN over TLS, replace the URI prefix + /// "turn:" with "turns:" + /// + /// example: ["turn:example.turn.uri?transport=udp", + /// "turn:example.turn.uri?transport=tcp"] + /// + /// default: [] + #[serde(default)] pub turn_uris: Vec<String>, + + /// TURN secret to use for generating the HMAC-SHA1 hash as part of username + /// and password generation + /// + /// this is more secure, but if needed you can use traditional + /// static username/password credentials. #[serde(default)] pub turn_secret: String, + + /// TURN secret to use that's read from the file path specified + /// + /// this takes priority over "turn_secret", and falls back to + /// "turn_secret" if invalid or failed to open. + /// + /// example: "/etc/conduwuit/.turn_secret" pub turn_secret_file: Option<PathBuf>, + + /// TURN TTL in seconds + /// + /// default: 86400 #[serde(default = "default_turn_ttl")] pub turn_ttl: u64, + /// List/vector of room IDs or room aliases that conduwuit will make newly + /// registered users join. The rooms specified must be rooms that you + /// have joined at least once on the server, and must be public. + /// + /// example: ["#conduwuit:puppygock.gay", + /// "!eoIzvAvVwY23LPDay8:puppygock.gay"] + /// + /// default: [] #[serde(default = "Vec::new")] - pub auto_join_rooms: Vec<OwnedRoomId>, + pub auto_join_rooms: Vec<OwnedRoomOrAliasId>, + + /// Config option to automatically deactivate the account of any user who + /// attempts to join a: + /// - banned room + /// - forbidden room alias + /// - room alias or ID with a forbidden server name + /// + /// This may be useful if all your banned lists consist of toxic rooms or + /// servers that no good faith user would ever attempt to join, and + /// to automatically remediate the problem without any admin user + /// intervention. + /// + /// This will also make the user leave all rooms. Federation (e.g. remote + /// room invites) is ignored here. + /// + /// Defaults to false as rooms can be banned for non-moderation-related + /// reasons #[serde(default)] pub auto_deactivate_banned_room_attempts: bool, + /// RocksDB log level. This is not the same as conduwuit's log level. This + /// is the log level for the RocksDB engine/library which shows up in your + /// database folder/path as `LOG` files. conduwuit will log RocksDB errors + /// as normal through tracing. + /// + /// default: "error" #[serde(default = "default_rocksdb_log_level")] pub rocksdb_log_level: String, + #[serde(default)] pub rocksdb_log_stderr: bool, + + /// Max RocksDB `LOG` file size before rotating in bytes. Defaults to 4MB in + /// bytes.
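For context on the turn_secret option above: a shared TURN secret is commonly used with the ephemeral-credential scheme coturn calls its "REST API" auth, where the username is a unix expiry timestamp and the password is base64(HMAC-SHA1(secret, username)). A hedged sketch using the hmac, sha1, and base64 crates; this is an illustration of the scheme, not conduwuit's actual implementation:

// Sketch of the ephemeral TURN credential scheme a shared secret enables:
// username = unix expiry timestamp, password = base64(HMAC-SHA1(secret, username)).
use std::time::{SystemTime, UNIX_EPOCH};

use base64::{engine::general_purpose::STANDARD, Engine as _};
use hmac::{Hmac, Mac};
use sha1::Sha1;

fn turn_credentials(turn_secret: &str, turn_ttl: u64) -> (String, String) {
    let expiry = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("clock after epoch")
        .as_secs()
        + turn_ttl;

    let username = expiry.to_string();

    let mut mac = Hmac::<Sha1>::new_from_slice(turn_secret.as_bytes())
        .expect("HMAC accepts keys of any length");
    mac.update(username.as_bytes());
    let password = STANDARD.encode(mac.finalize().into_bytes());

    (username, password)
}

fn main() {
    let (user, pass) = turn_credentials("super-secret", 86_400);
    println!("username: {user}, credential: {pass}");
}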
+ /// + /// default: 4194304 #[serde(default = "default_rocksdb_max_log_file_size")] pub rocksdb_max_log_file_size: usize, + + /// Time in seconds before RocksDB will forcibly rotate logs. + /// + /// default: 0 #[serde(default = "default_rocksdb_log_time_to_roll")] pub rocksdb_log_time_to_roll: usize, + + /// Set this to true to use RocksDB config options that are tailored to HDDs + /// (slower device storage) + /// + /// It is worth noting that by default, conduwuit will use RocksDB with + /// Direct IO enabled. *Generally* speaking this improves performance as it + /// bypasses buffered I/O (system page cache). However there is a potential + /// chance that Direct IO may cause issues with database operations if your + /// setup is uncommon. This has been observed with FUSE filesystems, and + /// possibly ZFS filesystem. RocksDB generally deals/corrects these issues + /// but it cannot account for all setups. If you experience any weird + /// RocksDB issues, try enabling this option as it turns off Direct IO and + /// feel free to report in the conduwuit Matrix room if this option fixes + /// your DB issues. + /// + /// See https://github.com/facebook/rocksdb/wiki/Direct-IO for more information. #[serde(default)] pub rocksdb_optimize_for_spinning_disks: bool, + + /// Enables direct-io to increase database performance via unbuffered I/O. + /// + /// See https://github.com/facebook/rocksdb/wiki/Direct-IO for more details about Direct IO and RocksDB. + /// + /// Set this option to false if the database resides on a filesystem which + /// does not support direct-io like FUSE, or any form of complex filesystem + /// setup such as possibly ZFS. #[serde(default = "true_fn")] pub rocksdb_direct_io: bool, + + /// Number of threads that RocksDB will use for parallelism on database + /// operations such as cleanup, sync, flush, compaction, etc. Set to 0 to use + /// all your logical threads. Defaults to your CPU logical thread count. + /// + /// default: 0 #[serde(default = "default_rocksdb_parallelism_threads")] pub rocksdb_parallelism_threads: usize, + + /// Maximum number of LOG files RocksDB will keep. This must *not* be set to + /// 0. It must be at least 1. Defaults to 3 as these are not very useful + /// unless troubleshooting/debugging a RocksDB bug. + /// + /// default: 3 #[serde(default = "default_rocksdb_max_log_files")] pub rocksdb_max_log_files: usize, + + /// Type of RocksDB database compression to use. + /// + /// Available options are "zstd", "zlib", "bz2", "lz4", or "none" + /// + /// It is best to use ZSTD as an overall good balance between + /// speed/performance, storage, IO amplification, and CPU usage. + /// For more performance but less compression (more storage used) and less + /// CPU usage, use LZ4. See https://github.com/facebook/rocksdb/wiki/Compression for more details. + /// + /// "none" will disable compression. + /// + /// default: "zstd" #[serde(default = "default_rocksdb_compression_algo")] pub rocksdb_compression_algo: String, + + /// Level of compression that the specified compression algorithm for RocksDB + /// will use. + /// + /// Default is 32767, which is internally read by RocksDB as the + /// default magic number and translated to the library's default + /// compression level as they all differ. + /// See their `kDefaultCompressionLevel`.
+ /// + /// default: 32767 #[serde(default = "default_rocksdb_compression_level")] pub rocksdb_compression_level: i32, + + /// Level of compression that the specified compression algorithm will use + /// for the bottommost level/data in RocksDB. Default is 32767, which is + /// internally read by RocksDB as the default magic number and translated + /// to the library's default compression level as they all differ. + /// See their `kDefaultCompressionLevel`. + /// + /// Since this is the bottommost level (generally old and least used data), + /// it may be desirable to have a very high compression level here as it's + /// less likely for this data to be used. Research your chosen compression + /// algorithm. + /// + /// default: 32767 #[serde(default = "default_rocksdb_bottommost_compression_level")] pub rocksdb_bottommost_compression_level: i32, + + /// Whether to enable RocksDB's "bottommost_compression". + /// + /// At the expense of more CPU usage, this will further compress the + /// database to reduce more storage. It is recommended to use ZSTD + /// compression with this for best compression results. This may be useful + /// if you're trying to reduce storage usage from the database. + /// + /// See https://github.com/facebook/rocksdb/wiki/Compression for more details. #[serde(default)] pub rocksdb_bottommost_compression: bool, + + /// Database recovery mode (for RocksDB WAL corruption) + /// + /// Use this option when the server reports corruption and refuses to start. + /// Set mode 2 (PointInTime) to cleanly recover from this corruption. The + /// server will continue from the last good state, several seconds or + /// minutes prior to the crash. Clients may have to run "clear-cache & + /// reload" to account for the rollback. Upon success, you may reset the + /// mode back to default and restart again. Please note in some cases the + /// corruption error may not be cleared for at least 30 minutes of + /// operation in PointInTime mode. + /// + /// As a very last ditch effort, if PointInTime does not fix or resolve + /// anything, you can try mode 3 (SkipAnyCorruptedRecord) but this will + /// leave the server in a potentially inconsistent state. + /// + /// The default mode 1 (TolerateCorruptedTailRecords) will automatically + /// drop the last entry in the database if corrupted during shutdown, but + /// nothing more. It is extraordinarily unlikely this will desynchronize + /// clients. To disable any form of silent rollback set mode 0 + /// (AbsoluteConsistency). + /// + /// The options are: + /// 0 = AbsoluteConsistency + /// 1 = TolerateCorruptedTailRecords (default) + /// 2 = PointInTime (use me if trying to recover) + /// 3 = SkipAnyCorruptedRecord (you now voided your Conduwuit warranty) + /// + /// See https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes for more information on these modes. + /// + /// See https://conduwuit.puppyirl.gay/troubleshooting.html#database-corruption for more details on recovering a corrupt database. + /// + /// default: 1 #[serde(default = "default_rocksdb_recovery_mode")] pub rocksdb_recovery_mode: u8, + + /// Database repair mode (for RocksDB SST corruption) + /// + /// Use this option when the server reports corruption while running or + /// panics. If the server refuses to start use the recovery mode options + /// first. Corruption errors containing the acronym 'SST' which occur after + /// startup will likely require this option. + /// + /// - Backing up your database directory is recommended prior to running the + /// repair.
+ /// - Disabling repair mode and restarting the server is recommended after + /// running the repair. + /// + /// See https://conduwuit.puppyirl.gay/troubleshooting.html#database-corruption for more details on recovering a corrupt database. #[serde(default)] pub rocksdb_repair: bool, + #[serde(default)] pub rocksdb_read_only: bool, + + #[serde(default)] + pub rocksdb_secondary: bool, + + /// Enables idle CPU priority for the compaction thread. This is not enabled by + /// default to prevent compaction from falling too far behind on busy + /// systems. #[serde(default)] pub rocksdb_compaction_prio_idle: bool, + + /// Enables idle IO priority for the compaction thread. This prevents any + /// unexpected lag in the server's operation and is usually a good idea. + /// Enabled by default. #[serde(default = "true_fn")] pub rocksdb_compaction_ioprio_idle: bool, + + /// Config option to disable RocksDB compaction. You should never ever have + /// to disable this. If you for some reason find yourself needing to disable + /// this as part of troubleshooting or a bug, please reach out to us in the + /// conduwuit Matrix room with information and details. + /// + /// Disabling compaction will lead to a significantly bloated and + /// explosively large database, gradually degrading performance, unnecessarily + /// excessive disk read/writes, and slower shutdowns and startups. #[serde(default = "true_fn")] pub rocksdb_compaction: bool, + + /// Level of statistics collection. Some admin commands to display database + /// statistics may require this option to be set. Database performance may + /// be impacted by higher settings. + /// + /// Option is a number ranging from 0 to 6: + /// 0 = No statistics. + /// 1 = No statistics in release mode (default). + /// 2 to 3 = Statistics with no performance impact. + /// 3 to 5 = Statistics with possible performance impact. + /// 6 = All statistics. + /// + /// default: 1 #[serde(default = "default_rocksdb_stats_level")] pub rocksdb_stats_level: u8, + /// This is a password that can be configured to let you log in to the + /// server bot account (currently `@conduit`) for emergency troubleshooting + /// purposes such as recovering/recreating your admin room, or inviting + /// yourself back. + /// + /// See https://conduwuit.puppyirl.gay/troubleshooting.html#lost-access-to-admin-room for other ways to get back into your admin room. + /// + /// Once this password is unset, all sessions will be logged out for + /// security purposes. + /// + /// example: "F670$2CP@Hw8mG7RY1$%!#Ic7YA" pub emergency_password: Option<String>, + /// default: "/_matrix/push/v1/notify" #[serde(default = "default_notification_push_path")] pub notification_push_path: String, + /// Config option to control local (your server only) presence + /// updates/requests. Note that presence on conduwuit is + /// very fast unlike Synapse's. If using outgoing presence, this MUST be + /// enabled. #[serde(default = "true_fn")] pub allow_local_presence: bool, + + /// Config option to control incoming federated presence updates/requests. + /// + /// This option receives presence updates from other + /// servers, but does not send any unless `allow_outgoing_presence` is true. + /// Note that presence on conduwuit is very fast unlike Synapse's. #[serde(default = "true_fn")] pub allow_incoming_presence: bool, + + /// Config option to control outgoing presence updates/requests. + /// + /// This option sends presence updates to other servers, but does not + /// receive any unless `allow_incoming_presence` is true.
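Returning to the rocksdb_recovery_mode numbers documented further above: they correspond to RocksDB's WAL recovery modes. A hedged sketch of the mapping using the rust-rocksdb crate's DBRecoveryMode; conduwuit's actual wiring may differ:

// Sketch: translate the numeric rocksdb_recovery_mode config value into
// rust-rocksdb's DBRecoveryMode, defaulting to TolerateCorruptedTailRecords (1).
use rocksdb::{DBRecoveryMode, Options};

fn set_recovery_mode(opts: &mut Options, mode: u8) {
    opts.set_wal_recovery_mode(match mode {
        0 => DBRecoveryMode::AbsoluteConsistency,
        2 => DBRecoveryMode::PointInTime,            // use me if trying to recover
        3 => DBRecoveryMode::SkipAnyCorruptedRecord, // last resort
        _ => DBRecoveryMode::TolerateCorruptedTailRecords, // default (1)
    });
}

fn main() {
    let mut opts = Options::default();
    set_recovery_mode(&mut opts, 2);
}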
+	/// Note that presence on conduwuit is very fast, unlike Synapse's. If
+	/// using outgoing presence, you MUST enable `allow_local_presence` as
+	/// well.
 	#[serde(default = "true_fn")]
 	pub allow_outgoing_presence: bool,
+
+	/// Config option to control how many seconds of inactivity before
+	/// presence marks you as idle. Defaults to 5 minutes.
+	///
+	/// default: 300
 	#[serde(default = "default_presence_idle_timeout_s")]
 	pub presence_idle_timeout_s: u64,
+
+	/// Config option to control how many seconds of inactivity before
+	/// presence marks you as offline. Defaults to 30 minutes.
+	///
+	/// default: 1800
 	#[serde(default = "default_presence_offline_timeout_s")]
 	pub presence_offline_timeout_s: u64,
+
+	/// Config option to enable the presence idle timer for remote users.
+	/// Disabling is offered as an optimization for servers participating in
+	/// many large rooms or when resources are limited. Disabling it may cause
+	/// incorrect presence states (i.e. stuck online) to be seen for some
+	/// remote users.
 	#[serde(default = "true_fn")]
 	pub presence_timeout_remote_users: bool,
 
+	/// Config option to control whether we should receive remote incoming
+	/// read receipts.
 	#[serde(default = "true_fn")]
 	pub allow_incoming_read_receipts: bool,
+
+	/// Config option to control whether we should send read receipts to
+	/// remote servers.
 	#[serde(default = "true_fn")]
 	pub allow_outgoing_read_receipts: bool,
 
+	/// Config option to control outgoing typing updates to federation.
 	#[serde(default = "true_fn")]
 	pub allow_outgoing_typing: bool,
+
+	/// Config option to control incoming typing updates from federation.
 	#[serde(default = "true_fn")]
 	pub allow_incoming_typing: bool,
+
+	/// Config option to control the maximum time a federation user can
+	/// indicate typing.
+	///
+	/// default: 30
 	#[serde(default = "default_typing_federation_timeout_s")]
 	pub typing_federation_timeout_s: u64,
+
+	/// Config option to control the minimum time a local client can indicate
+	/// typing. This does not override a client's request to stop typing. It
+	/// only enforces a minimum value in case of no stop request.
+	///
+	/// default: 15
 	#[serde(default = "default_typing_client_timeout_min_s")]
 	pub typing_client_timeout_min_s: u64,
+
+	/// Config option to control the maximum time a local client can indicate
+	/// typing.
+	///
+	/// default: 45
 	#[serde(default = "default_typing_client_timeout_max_s")]
 	pub typing_client_timeout_max_s: u64,
 
+	/// Set this to true for conduwuit to compress HTTP response bodies using
+	/// zstd. This option does nothing if conduwuit was not built with the
+	/// `zstd_compression` feature. Please be aware that enabling HTTP
+	/// compression may weaken TLS. Most users should not need to enable this.
+	/// See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH
+	/// before deciding to enable this.
 	#[serde(default)]
 	pub zstd_compression: bool,
+
+	/// Set this to true for conduwuit to compress HTTP response bodies using
+	/// gzip. This option does nothing if conduwuit was not built with the
+	/// `gzip_compression` feature. Please be aware that enabling HTTP
+	/// compression may weaken TLS. Most users should not need to enable this.
+	/// See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH
+	/// before deciding to enable this.
+	///
+	/// If you are in a large number of rooms, you may find that enabling this
+	/// is necessary to reduce the significantly large response bodies.
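+	///
+	/// As a sketch (an editor's illustration, not a recommendation; the keys
+	/// mirror these field names), a server opting into gzip only, after
+	/// weighing the BREACH caveat above, would set:
+	///
+	/// gzip_compression = true
+	/// zstd_compression = false
+	/// brotli_compression = false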
 	#[serde(default)]
 	pub gzip_compression: bool,
+
+	/// Set this to true for conduwuit to compress HTTP response bodies using
+	/// brotli. This option does nothing if conduwuit was not built with the
+	/// `brotli_compression` feature. Please be aware that enabling HTTP
+	/// compression may weaken TLS. Most users should not need to enable this.
+	/// See https://breachattack.com/ and https://wikipedia.org/wiki/BREACH
+	/// before deciding to enable this.
 	#[serde(default)]
 	pub brotli_compression: bool,
 
+	/// Set to true to allow user type "guest" registrations. Some clients
+	/// like Element attempt to register guest users automatically.
 	#[serde(default)]
 	pub allow_guest_registration: bool,
+
+	/// Set to true to log guest registrations in the admin room. Note that
+	/// these may be noisy or unnecessary if you're a public homeserver.
 	#[serde(default)]
 	pub log_guest_registrations: bool,
+
+	/// Set to true to allow guest registrations/users to auto join any rooms
+	/// specified in `auto_join_rooms`.
 	#[serde(default)]
 	pub allow_guests_auto_join_rooms: bool,
 
+	/// Config option to control whether the legacy unauthenticated Matrix
+	/// media repository endpoints will be enabled. These endpoints consist of:
+	/// - /_matrix/media/*/config
+	/// - /_matrix/media/*/upload
+	/// - /_matrix/media/*/preview_url
+	/// - /_matrix/media/*/download/*
+	/// - /_matrix/media/*/thumbnail/*
+	///
+	/// The authenticated equivalent endpoints are always enabled.
+	///
+	/// Defaults to true for now, but this is highly subject to change, likely
+	/// in the next release.
 	#[serde(default = "true_fn")]
 	pub allow_legacy_media: bool,
+
 	#[serde(default = "true_fn")]
 	pub freeze_legacy_media: bool,
+
+	/// Checks consistency of the media directory at startup:
+	/// 1. When `media_compat_file_link` is enabled, this check will upgrade
+	///    media when switching back and forth between Conduit and conduwuit.
+	///    Both options must be enabled to handle this.
+	/// 2. When media is deleted from the directory, this check will also
+	///    delete its database entry.
+	///
+	/// If none of these checks apply to your use case and your media
+	/// directory is significantly large, setting this to false may reduce
+	/// startup time.
 	#[serde(default = "true_fn")]
 	pub media_startup_check: bool,
+
+	/// Enable backward-compatibility with Conduit's media directory by
+	/// creating symlinks of media. This option is only necessary if you plan
+	/// on using Conduit again. Otherwise, setting this to false reduces
+	/// filesystem clutter and overhead for managing these symlinks in the
+	/// directory. This is now disabled by default. You may still return to
+	/// upstream Conduit, but you have to run conduwuit at least once with
+	/// this set to true and allow the media_startup_check to take place
+	/// before shutting down to return to Conduit.
 	#[serde(default)]
 	pub media_compat_file_link: bool,
+
+	/// Prunes missing media from the database as part of the media startup
+	/// checks. This means if you delete files from the media directory the
+	/// corresponding entries will be removed from the database. This is
+	/// disabled by default because if the media directory is accidentally
+	/// moved or inaccessible, the metadata entries in the database will be
+	/// lost with sadness.
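+	///
+	/// A combined sketch (an editor's illustration; the keys mirror these
+	/// field names) for a server that has permanently left Conduit behind
+	/// and accepts the pruning trade-off described above:
+	///
+	/// media_compat_file_link = false
+	/// media_startup_check = true
+	/// prune_missing_media = true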
 	#[serde(default)]
 	pub prune_missing_media: bool,
 
-	#[serde(default = "Vec::new")]
-	pub prevent_media_downloads_from: Vec<OwnedServerName>,
-
-	#[serde(default = "Vec::new")]
-	pub forbidden_remote_server_names: Vec<OwnedServerName>,
-	#[serde(default = "Vec::new")]
-	pub forbidden_remote_room_directory_server_names: Vec<OwnedServerName>,
+	/// Vector list of servers that conduwuit will refuse to download remote
+	/// media from.
+	///
+	/// default: []
+	#[serde(default)]
+	pub prevent_media_downloads_from: HashSet<OwnedServerName>,
+
+	/// List of forbidden server names that we will block incoming AND
+	/// outgoing federation with, and block client room joins / remote user
+	/// invites.
+	///
+	/// This check is applied on the room ID, room alias, sender server name,
+	/// sender user's server name, inbound federation X-Matrix origin, and
+	/// outbound federation handler.
+	///
+	/// Basically "global" ACLs.
+	///
+	/// default: []
+	#[serde(default)]
+	pub forbidden_remote_server_names: HashSet<OwnedServerName>,
+
+	/// List of forbidden server names that we will block all outgoing
+	/// federated room directory requests for. Useful for preventing our
+	/// users from wandering into bad servers or spaces.
+	///
+	/// default: []
+	#[serde(default = "HashSet::new")]
+	pub forbidden_remote_room_directory_server_names: HashSet<OwnedServerName>,
+
+	/// Vector list of IPv4 and IPv6 CIDR ranges / subnets *in quotes* that
+	/// you do not want conduwuit to send outbound requests to. Defaults to
+	/// RFC1918, unroutable, loopback, multicast, and testnet addresses for
+	/// security.
+	///
+	/// Please be aware that this is *not* a guarantee. You should be using a
+	/// firewall with zones, as doing this on the application layer may have
+	/// bypasses.
+	///
+	/// Currently this does not account for proxies in use like Synapse does.
+	///
+	/// To disable, set this to be an empty vector (`[]`).
+	///
+	/// default: ["127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12",
+	/// "192.168.0.0/16", "100.64.0.0/10", "192.0.0.0/24", "169.254.0.0/16",
+	/// "192.88.99.0/24", "198.18.0.0/15", "192.0.2.0/24", "198.51.100.0/24",
+	/// "203.0.113.0/24", "224.0.0.0/4", "::1/128", "fe80::/10", "fc00::/7",
+	/// "2001:db8::/32", "ff00::/8", "fec0::/10"]
 	#[serde(default = "default_ip_range_denylist")]
 	pub ip_range_denylist: Vec<String>,
 
-	#[serde(default = "Vec::new")]
+	/// Vector list of domains allowed to send requests to for URL previews.
+	/// Defaults to none. Note: this is a *contains* match, not an explicit
+	/// match. Putting "google.com" will match "https://google.com" and
+	/// "http://mymaliciousdomainexamplegoogle.com". Setting this to "*" will
+	/// allow all URL previews. Please note that this opens up significant
+	/// attack surface to your server; you are expected to be aware of the
+	/// risks of doing so.
+	///
+	/// default: []
+	#[serde(default)]
 	pub url_preview_domain_contains_allowlist: Vec<String>,
 
-	#[serde(default = "Vec::new")]
+
+	/// Vector list of explicit domains allowed to send requests to for URL
+	/// previews. Defaults to none. Note: This is an *explicit* match, not a
+	/// contains match. Putting "google.com" will match "https://google.com",
+	/// "http://google.com", but not
+	/// "https://mymaliciousdomainexamplegoogle.com". Setting this to "*" will
+	/// allow all URL previews. Please note that this opens up significant
+	/// attack surface to your server; you are expected to be aware of the
+	/// risks of doing so.
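+	///
+	/// As an illustrative sketch (an editor's example with hypothetical
+	/// values, not defaults), an explicit allowlist permitting previews for
+	/// exactly two domains might look like:
+	///
+	/// url_preview_domain_explicit_allowlist = ["matrix.org", "wikipedia.org"]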
+	///
+	/// default: []
+	#[serde(default)]
 	pub url_preview_domain_explicit_allowlist: Vec<String>,
 
-	#[serde(default = "Vec::new")]
+
+	/// Vector list of explicit domains not allowed to send requests to for
+	/// URL previews. Defaults to none. Note: This is an *explicit* match, not
+	/// a contains match. Putting "google.com" will match
+	/// "https://google.com", "http://google.com", but not
+	/// "https://mymaliciousdomainexamplegoogle.com". The denylist is checked
+	/// first, before the allowlist. Setting this to "*" will not do anything.
+	///
+	/// default: []
+	#[serde(default)]
 	pub url_preview_domain_explicit_denylist: Vec<String>,
 
-	#[serde(default = "Vec::new")]
+
+	/// Vector list of URLs allowed to send requests to for URL previews.
+	/// Defaults to none. Note that this is a *contains* match, not an
+	/// explicit match. Putting "google.com" will match
+	/// "https://google.com/",
+	/// "https://google.com/url?q=https://mymaliciousdomainexample.com", and
+	/// "https://mymaliciousdomainexample.com/hi/google.com". Setting this to
+	/// "*" will allow all URL previews. Please note that this opens up
+	/// significant attack surface to your server; you are expected to be
+	/// aware of the risks of doing so.
+	///
+	/// default: []
+	#[serde(default)]
 	pub url_preview_url_contains_allowlist: Vec<String>,
+
+	/// Maximum number of bytes allowed in a URL preview body when spidering.
+	/// Defaults to 384KB.
+	///
+	/// default: 384000
 	#[serde(default = "default_url_preview_max_spider_size")]
 	pub url_preview_max_spider_size: usize,
+
+	/// Option to decide whether you would like to run the domain allowlist
+	/// checks (contains and explicit) on the root domain or not. Does not
+	/// apply to the URL contains allowlist. Defaults to false.
+	///
+	/// Example use case: if this is enabled and you have "wikipedia.org"
+	/// allowed in the explicit and/or contains domain allowlist, it will
+	/// allow all subdomains under "wikipedia.org" such as
+	/// "en.m.wikipedia.org" as the root domain is checked and matched. Useful
+	/// if the domain contains allowlist is still too broad for you but you
+	/// still want to allow all the subdomains under a root domain.
 	#[serde(default)]
 	pub url_preview_check_root_domain: bool,
 
-	#[serde(default = "RegexSet::empty")]
+	/// List of forbidden room aliases and room IDs as strings of regex
+	/// patterns.
+	///
+	/// Regex can be used, or explicit contains matches can be done by just
+	/// specifying the words (see example).
+	///
+	/// This is checked upon room alias creation, custom room ID creation if
+	/// used, and startup as warnings if any room aliases in your database
+	/// have a forbidden room alias/ID.
+	///
+	/// example: ["19dollarfortnitecards", "b[4a]droom"]
+	///
+	/// default: []
+	#[serde(default)]
 	#[serde(with = "serde_regex")]
 	pub forbidden_alias_names: RegexSet,
 
-	#[serde(default = "RegexSet::empty")]
+	/// List of forbidden username patterns/strings.
+	///
+	/// Regex can be used, or explicit contains matches can be done by just
+	/// specifying the words (see example).
+	///
+	/// This is checked upon username availability check, registration, and
+	/// startup as warnings if any local users in your database have a
+	/// forbidden username.
+	///
+	/// example: ["administrator", "b[a4]dusernam[3e]"]
+	///
+	/// default: []
+	#[serde(default)]
 	#[serde(with = "serde_regex")]
 	pub forbidden_usernames: RegexSet,
 
+	/// Retry failed and incomplete messages to remote servers immediately
+	/// upon startup. This is called bursting. If this is disabled, said
+	/// messages
+	/// may not be delivered until more messages are queued for that server.
+	/// Do not change this option unless server resources are extremely
+	/// limited or the scale of the server's deployment is huge. Do not
+	/// disable this unless you know what you are doing.
 	#[serde(default = "true_fn")]
 	pub startup_netburst: bool,
+
+	/// The maximum number of messages to retain for the startup netburst;
+	/// excess messages are dropped and not reattempted. The
+	/// `startup_netburst` option must be enabled for this value to have any
+	/// effect. Do not change this value unless you know what you are doing.
+	/// Set this value to -1 to reattempt every message without trimming the
+	/// queues; this may consume significant disk. Set this value to 0 to
+	/// drop all messages without any attempt at redelivery.
+	///
+	/// default: 50
 	#[serde(default = "default_startup_netburst_keep")]
 	pub startup_netburst_keep: i64,
 
+	/// Controls whether non-admin local users are forbidden from sending
+	/// room invites (local and remote), and if non-admin users can receive
+	/// remote room invites. Admins are always allowed to send and receive
+	/// all room invites.
 	#[serde(default)]
 	pub block_non_admin_invites: bool,
+
+	/// Allows admins to enter commands in rooms other than "#admins" (admin
+	/// room) by prefixing your message with "\!admin" or "\\!admin" followed
+	/// by a normal conduwuit admin command. The reply will be publicly
+	/// visible to the room, originating from the sender.
+	///
+	/// example: \\!admin debug ping puppygock.gay
 	#[serde(default = "true_fn")]
 	pub admin_escape_commands: bool,
+
+	/// Controls whether the conduwuit admin room console / CLI will
+	/// immediately activate on startup. This option can also be enabled with
+	/// the `--console` conduwuit argument.
 	#[serde(default)]
 	pub admin_console_automatic: bool,
+
+	/// Controls what admin commands will be executed on startup. This is a
+	/// vector list of strings of admin commands to run.
+	///
+	/// This option can also be configured with the `--execute` conduwuit
+	/// argument and can take standard shell commands and environment
+	/// variables.
+	///
+	/// An example would be: `./conduwuit --execute "server admin-notice
+	/// conduwuit has started up at $(date)"`
+	///
+	/// example: admin_execute = ["debug ping puppygock.gay", "debug echo hi"]
+	///
+	/// default: []
 	#[serde(default)]
 	pub admin_execute: Vec<String>,
+
+	/// Controls whether conduwuit should error and fail to start if an admin
+	/// execute command (`--execute` / `admin_execute`) fails.
 	#[serde(default)]
 	pub admin_execute_errors_ignore: bool,
+
+	/// Controls the max log level for admin command log captures (logs
+	/// generated from running admin commands). Defaults to "info" on release
+	/// builds, else "debug" on debug builds.
+	///
+	/// default: "info"
 	#[serde(default = "default_admin_log_capture")]
 	pub admin_log_capture: String,
+
+	/// The default room tag to apply on the admin room.
+	///
+	/// On some clients like Element, the room tag "m.server_notice" is a
+	/// special pinned room at the very bottom of your room list. The
+	/// conduwuit admin room can be pinned here so you always have an
+	/// easy-to-access shortcut dedicated to your admin room.
+	///
+	/// default: "m.server_notice"
 	#[serde(default = "default_admin_room_tag")]
 	pub admin_room_tag: String,
 
+	/// Sentry.io crash/panic reporting, performance monitoring/metrics, etc.
+	/// This is NOT enabled by default.
+	/// conduwuit's default Sentry reporting endpoint is
+	/// o4506996327251968.ingest.us.sentry.io
 	#[serde(default)]
 	pub sentry: bool,
+
+	/// Sentry reporting URL, if a custom one is desired.
+	///
+	/// default: "https://fe2eb4536aa04949e28eff3128d64757@o4506996327251968.ingest.us.sentry.io/4506996334657536"
 	#[serde(default = "default_sentry_endpoint")]
 	pub sentry_endpoint: Option<Url>,
+
+	/// Report your conduwuit server_name in Sentry.io crash reports and
+	/// metrics.
 	#[serde(default)]
 	pub sentry_send_server_name: bool,
+
+	/// Performance monitoring/tracing sample rate for Sentry.io.
+	///
+	/// Note that too high values may impact performance, and can be disabled
+	/// by setting it to 0.0 (0%). This value is read as a percentage by
+	/// Sentry, represented as a decimal. Defaults to 15% of traces (0.15).
+	///
+	/// default: 0.15
 	#[serde(default = "default_sentry_traces_sample_rate")]
 	pub sentry_traces_sample_rate: f32,
+
+	/// Whether to attach a stacktrace to Sentry reports.
 	#[serde(default)]
 	pub sentry_attach_stacktrace: bool,
+
+	/// Send panics to Sentry. This is true by default, but the global
+	/// "sentry" config option must be enabled to send any data.
 	#[serde(default = "true_fn")]
 	pub sentry_send_panic: bool,
+
+	/// Send errors to Sentry. This is true by default, but the global
+	/// "sentry" config option must be enabled to send any data. This option
+	/// is only effective in release-mode; it is forced to false in
+	/// debug-mode.
 	#[serde(default = "true_fn")]
 	pub sentry_send_error: bool,
+
+	/// Controls the tracing log level for Sentry to send things like
+	/// breadcrumbs and transactions.
+	///
+	/// default: "info"
 	#[serde(default = "default_sentry_filter")]
 	pub sentry_filter: String,
 
+	/// Enable the tokio-console. This option is only relevant to developers.
+	/// See https://conduwuit.puppyirl.gay/development.html#debugging-with-tokio-console for more information.
 	#[serde(default)]
 	pub tokio_console: bool,
@@ -390,18 +1497,35 @@ pub struct Config {
 }
 
 #[derive(Clone, Debug, Deserialize)]
+#[config_example_generator(filename = "conduwuit-example.toml", section = "global.tls")]
 pub struct TlsConfig {
+	/// Path to a valid TLS certificate file.
+	///
+	/// example: "/path/to/my/certificate.crt"
 	pub certs: String,
+	/// Path to a valid TLS certificate private key.
+	///
+	/// example: "/path/to/my/certificate.key"
 	pub key: String,
-	#[serde(default)]
 	/// Whether to listen and allow for HTTP and HTTPS connections (insecure!)
+	#[serde(default)]
 	pub dual_protocol: bool,
 }
 
 #[derive(Clone, Debug, Deserialize, Default)]
+#[config_example_generator(filename = "conduwuit-example.toml", section = "global.well_known")]
 pub struct WellKnownConfig {
-	pub client: Option<Url>,
+	/// The server base domain of the URL with a specific port that the
+	/// server well-known file will serve. This should contain a port at the
+	/// end, and should not be a URL.
+	///
+	/// example: "matrix.example.com:443"
 	pub server: Option<OwnedServerName>,
+	/// The server URL that the client well-known file will serve. This
+	/// should not contain a port, and should just be a valid HTTPS URL.
+	///
+	/// example: "<https://matrix.example.com>"
+	pub client: Option<Url>,
 	pub support_page: Option<Url>,
 	pub support_role: Option<ContactRole>,
 	pub support_email: Option<String>,
@@ -436,34 +1560,26 @@ struct ListeningAddr {
 impl Config {
 	/// Pre-initialize config
-	pub fn load(paths: &Option<Vec<PathBuf>>) -> Result<Figment> {
-		let raw_config = if let Some(config_file_env) = Env::var("CONDUIT_CONFIG") {
-			Figment::new().merge(Toml::file(config_file_env).nested())
-		} else if let Some(config_file_arg) = Env::var("CONDUWUIT_CONFIG") {
-			Figment::new().merge(Toml::file(config_file_arg).nested())
-		} else if let Some(config_file_args) = paths {
-			let mut figment = Figment::new();
-
-			for config in config_file_args {
-				figment = figment.merge(Toml::file(config).nested());
-			}
+	pub fn load(paths: Option<&[PathBuf]>) -> Result<Figment> {
+		let paths_files = paths.into_iter().flatten().map(Toml::file);
 
-			figment
-		} else {
-			Figment::new()
-		};
+		let envs = [Env::var("CONDUIT_CONFIG"), Env::var("CONDUWUIT_CONFIG")];
+		let envs_files = envs.into_iter().flatten().map(Toml::file);
 
-		Ok(raw_config
+		let config = envs_files
+			.chain(paths_files)
+			.fold(Figment::new(), |config, file| config.merge(file.nested()))
 			.merge(Env::prefixed("CONDUIT_").global().split("__"))
-			.merge(Env::prefixed("CONDUWUIT_").global().split("__")))
+			.merge(Env::prefixed("CONDUWUIT_").global().split("__"));
+
+		Ok(config)
 	}
 
 	/// Finalize config
 	pub fn new(raw_config: &Figment) -> Result<Self> {
-		let config = match raw_config.extract::<Self>() {
-			Err(e) => return Err!("There was a problem with your configuration file: {e}"),
-			Ok(config) => config,
-		};
+		let config = raw_config
+			.extract::<Self>()
+			.map_err(|e| err!("There was a problem with your configuration file: {e}"))?;
 
 		// don't start if we're listening on both UNIX sockets and TCP at same time
 		check::is_dual_listening(raw_config)?;
@@ -506,13 +1622,12 @@ pub fn check(&self) -> Result<(), Error> { check(self) }
 
 impl fmt::Display for Config {
 	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-		writeln!(f, "Active config values:\n\n").expect("wrote line to formatter stream");
+		writeln!(f, "Active config values:\n").expect("wrote line to formatter stream");
 
 		let mut line = |key: &str, val: &str| {
 			writeln!(f, "{key}: {val}").expect("wrote line to formatter stream");
 		};
 
 		line("Server name", self.server_name.host());
-		line("Database backend", &self.database_backend);
 		line("Database path", &self.database_path.to_string_lossy());
 		line(
 			"Database backup path",
@@ -570,12 +1685,20 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 		line("Allow registration", &self.allow_registration.to_string());
 		line(
 			"Registration token",
-			if self.registration_token.is_some() {
-				"set"
+			if self.registration_token.is_none() && self.registration_token_file.is_none() && self.allow_registration {
+				"not set (⚠️ open registration!)"
+			} else if self.registration_token.is_none() && self.registration_token_file.is_none() {
+				"not set"
 			} else {
-				"not set (open registration!)"
+				"set"
 			},
 		);
+		line(
+			"Registration token file path",
+			self.registration_token_file
+				.as_ref()
+				.map_or("", |path| path.to_str().unwrap_or_default()),
+		);
 		line(
 			"Allow guest registration (inherently false if allow registration is false)",
 			&self.allow_guest_registration.to_string(),
@@ -592,6 +1715,10 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 		line("Allow encryption", &self.allow_encryption.to_string());
 		line("Allow federation", &self.allow_federation.to_string());
line("Federation loopback", &self.federation_loopback.to_string()); + line( + "Require authentication for profile requests", + &self.require_auth_for_profile_requests.to_string(), + ); line( "Allow incoming federated presence requests (updates)", &self.allow_incoming_presence.to_string(), @@ -639,7 +1766,9 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { line("Allow device name federation", &self.allow_device_name_federation.to_string()); line( "Allow incoming profile lookup federation requests", - &self.allow_profile_lookup_federation_requests.to_string(), + &self + .allow_inbound_profile_lookup_federation_requests + .to_string(), ); line( "Auto deactivate banned room join attempts", @@ -674,10 +1803,6 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { .map(|server| server.host()) .join(", "), ); - line( - "Query Trusted Key Servers First", - &self.query_trusted_key_servers_first.to_string(), - ); line("OpenID Token TTL", &self.openid_token_ttl.to_string()); line( "TURN username", @@ -752,6 +1877,7 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { line("RocksDB Recovery Mode", &self.rocksdb_recovery_mode.to_string()); line("RocksDB Repair Mode", &self.rocksdb_repair.to_string()); line("RocksDB Read-only Mode", &self.rocksdb_read_only.to_string()); + line("RocksDB Secondary Mode", &self.rocksdb_secondary.to_string()); line( "RocksDB Compaction Idle Priority", &self.rocksdb_compaction_prio_idle.to_string(), @@ -902,8 +2028,6 @@ fn default_unix_socket_perms() -> u32 { 660 } fn default_database_backups_to_keep() -> i16 { 1 } -fn default_database_backend() -> String { "rocksdb".to_owned() } - fn default_db_cache_capacity_mb() -> f64 { 128.0 + parallelism_scaled_f64(64.0) } fn default_pdu_cache_capacity() -> u32 { parallelism_scaled_u32(10_000).saturating_add(100_000) } @@ -958,7 +2082,7 @@ fn default_well_known_conn_timeout() -> u64 { 6 } fn default_well_known_timeout() -> u64 { 10 } -fn default_federation_timeout() -> u64 { 300 } +fn default_federation_timeout() -> u64 { 25 } fn default_federation_idle_timeout() -> u64 { 25 } @@ -976,7 +2100,7 @@ fn default_appservice_idle_timeout() -> u64 { 300 } fn default_pusher_idle_timeout() -> u64 { 15 } -fn default_max_fetch_prev_events() -> u16 { 100_u16 } +fn default_max_fetch_prev_events() -> u16 { 192_u16 } fn default_tracing_flame_filter() -> String { cfg!(debug_assertions) @@ -1005,6 +2129,9 @@ pub fn default_log() -> String { .to_owned() } +#[must_use] +pub fn default_log_span_events() -> String { "none".into() } + fn default_notification_push_path() -> String { "/_matrix/push/v1/notify".to_owned() } fn default_openid_token_ttl() -> u64 { 60 * 60 } @@ -1054,6 +2181,7 @@ fn default_rocksdb_stats_level() -> u8 { 1 } // I know, it's a great name #[must_use] +#[inline] pub fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 } fn default_ip_range_denylist() -> Vec<String> { @@ -1114,3 +2242,5 @@ fn parallelism_scaled_u32(val: u32) -> u32 { } fn parallelism_scaled(val: usize) -> usize { val.saturating_mul(sys::available_parallelism()) } + +fn default_trusted_server_batch_size() -> usize { 256 } diff --git a/src/core/debug.rs b/src/core/debug.rs index 844445d5390a1a03c9ce865f410c9c027e691ca3..f7420784e56d3ca23c982eefcd8e988f6c4d343f 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -1,10 +1,12 @@ +#![allow(clippy::disallowed_macros)] + use std::{any::Any, panic}; -/// Export debug proc_macros +// Export debug proc_macros pub use conduit_macros::recursion_depth; -/// Export all of the 
-pub use crate::utils::debug::*;
+// Export all of the ancillary tools from here as well.
+pub use crate::{result::DebugInspect, utils::debug::*};
 
 /// Log event at given level in debug-mode (when debug-assertions are enabled).
 /// In release-mode it becomes DEBUG level, and possibly subject to elision.
@@ -86,11 +88,11 @@ pub fn panic_str(p: &Box<dyn Any + Send>) -> &'static str { p.downcast_ref::<&st
 
 #[inline(always)]
 #[must_use]
-pub fn rttype_name<T>(_: &T) -> &'static str { type_name::<T>() }
+pub fn rttype_name<T: ?Sized>(_: &T) -> &'static str { type_name::<T>() }
 
 #[inline(always)]
 #[must_use]
-pub fn type_name<T>() -> &'static str { std::any::type_name::<T>() }
+pub fn type_name<T: ?Sized>() -> &'static str { std::any::type_name::<T>() }
 
 #[must_use]
 #[inline]
diff --git a/src/core/error/err.rs b/src/core/error/err.rs
index b3d0240ed20c6aa6dd4a8ff7eff5f7742d0e2b7c..a24441e0025d0e7cfa50fab92b9af9b8dc17db66 100644
--- a/src/core/error/err.rs
+++ b/src/core/error/err.rs
@@ -44,34 +44,34 @@ macro_rules! err {
 	(Request(Forbidden($level:ident!($($args:tt)+)))) => {{
 		let mut buf = String::new();
 		$crate::error::Error::Request(
-			::ruma::api::client::error::ErrorKind::forbidden(),
+			$crate::ruma::api::client::error::ErrorKind::forbidden(),
 			$crate::err_log!(buf, $level, $($args)+),
-			::http::StatusCode::BAD_REQUEST
+			$crate::http::StatusCode::BAD_REQUEST
 		)
 	}};
 
 	(Request(Forbidden($($args:tt)+))) => {
 		$crate::error::Error::Request(
-			::ruma::api::client::error::ErrorKind::forbidden(),
+			$crate::ruma::api::client::error::ErrorKind::forbidden(),
 			$crate::format_maybe!($($args)+),
-			::http::StatusCode::BAD_REQUEST
+			$crate::http::StatusCode::BAD_REQUEST
 		)
 	};
 
 	(Request($variant:ident($level:ident!($($args:tt)+)))) => {{
 		let mut buf = String::new();
 		$crate::error::Error::Request(
-			::ruma::api::client::error::ErrorKind::$variant,
+			$crate::ruma::api::client::error::ErrorKind::$variant,
 			$crate::err_log!(buf, $level, $($args)+),
-			::http::StatusCode::BAD_REQUEST
+			$crate::http::StatusCode::BAD_REQUEST
 		)
 	}};
 
 	(Request($variant:ident($($args:tt)+))) => {
 		$crate::error::Error::Request(
-			::ruma::api::client::error::ErrorKind::$variant,
+			$crate::ruma::api::client::error::ErrorKind::$variant,
 			$crate::format_maybe!($($args)+),
-			::http::StatusCode::BAD_REQUEST
+			$crate::http::StatusCode::BAD_REQUEST
 		)
 	};
 
@@ -85,6 +85,10 @@ macro_rules! err {
 		$crate::error::Error::$variant($crate::err_log!(buf, $level, $($args)+))
 	}};
 
+	($variant:ident($($args:ident),+)) => {
+		$crate::error::Error::$variant($($args),+)
+	};
+
 	($variant:ident($($args:tt)+)) => {
 		$crate::error::Error::$variant($crate::format_maybe!($($args)+))
 	};
@@ -107,12 +111,8 @@ macro_rules! err {
 #[macro_export]
 macro_rules! err_log {
 	($out:ident, $level:ident, $($fields:tt)+) => {{
-		use std::{fmt, fmt::Write};
-
-		use ::tracing::{
-			callsite, callsite2, level_enabled, metadata, valueset, Callsite, Event, __macro_support,
-			__tracing_log,
-			field::{Field, ValueSet, Visit},
+		use $crate::tracing::{
+			callsite, callsite2, metadata, valueset, Callsite, Level,
 		};
 
@@ -130,33 +130,7 @@ macro_rules! err_log {
 			fields: $($fields)+,
 		};
 
-		let visit = &mut |vs: ValueSet<'_>| {
-			struct Visitor<'a>(&'a mut String);
-			impl Visit for Visitor<'_> {
-				fn record_debug(&mut self, field: &Field, val: &dyn fmt::Debug) {
-					if field.name() == "message" {
-						write!(self.0, "{:?}", val).expect("stream error");
-					} else {
-						write!(self.0, " {}={:?}", field.name(), val).expect("stream error");
-					}
-				}
-			}
-
-			let meta = __CALLSITE.metadata();
-			let enabled = level_enabled!(LEVEL) && {
-				let interest = __CALLSITE.interest();
-				!interest.is_never() && __macro_support::__is_enabled(meta, interest)
-			};
-
-			if enabled {
-				Event::dispatch(meta, &vs);
-			}
-
-			__tracing_log!(LEVEL, __CALLSITE, &vs);
-			vs.record(&mut Visitor(&mut $out));
-		};
-
-		(visit)(valueset!(__CALLSITE.metadata().fields(), $($fields)+));
+		($crate::error::visit)(&mut $out, LEVEL, &__CALLSITE, &mut valueset!(__CALLSITE.metadata().fields(), $($fields)+));
 
 		($out).into()
 	}}
 }
@@ -165,25 +139,62 @@ fn record_debug(&mut self, field: &Field, val: &dyn fmt::Debug) {
 macro_rules! err_lev {
 	(debug_warn) => {
 		if $crate::debug::logging() {
-			::tracing::Level::WARN
+			$crate::tracing::Level::WARN
 		} else {
-			::tracing::Level::DEBUG
+			$crate::tracing::Level::DEBUG
 		}
 	};
 
 	(debug_error) => {
 		if $crate::debug::logging() {
-			::tracing::Level::ERROR
+			$crate::tracing::Level::ERROR
 		} else {
-			::tracing::Level::DEBUG
+			$crate::tracing::Level::DEBUG
 		}
 	};
 
 	(warn) => {
-		::tracing::Level::WARN
+		$crate::tracing::Level::WARN
 	};
 
 	(error) => {
-		::tracing::Level::ERROR
+		$crate::tracing::Level::ERROR
 	};
 }
+
+use std::{fmt, fmt::Write};
+
+use tracing::{
+	level_enabled, Callsite, Event, __macro_support, __tracing_log,
+	callsite::DefaultCallsite,
+	field::{Field, ValueSet, Visit},
+	Level,
+};
+
+struct Visitor<'a>(&'a mut String);
+
+impl Visit for Visitor<'_> {
+	#[inline]
+	fn record_debug(&mut self, field: &Field, val: &dyn fmt::Debug) {
+		if field.name() == "message" {
+			write!(self.0, "{val:?}").expect("stream error");
+		} else {
+			write!(self.0, " {}={val:?}", field.name()).expect("stream error");
+		}
+	}
+}
+
+pub fn visit(out: &mut String, level: Level, __callsite: &'static DefaultCallsite, vs: &mut ValueSet<'_>) {
+	let meta = __callsite.metadata();
+	let enabled = level_enabled!(level) && {
+		let interest = __callsite.interest();
+		!interest.is_never() && __macro_support::__is_enabled(meta, interest)
	};
+
+	if enabled {
+		Event::dispatch(meta, vs);
+	}
+
+	__tracing_log!(level, __callsite, vs);
+	vs.record(&mut Visitor(out));
+}
diff --git a/src/core/error/log.rs b/src/core/error/log.rs
index c272bf730c42ede2dd18d18e313546bdb18bf241..60bd7014061344040941b7e54f318695eff54ef2 100644
--- a/src/core/error/log.rs
+++ b/src/core/error/log.rs
@@ -1,7 +1,8 @@
 use std::{convert::Infallible, fmt};
 
+use tracing::Level;
+
 use super::Error;
-use crate::{debug_error, error};
 
 #[inline]
 pub fn else_log<T, E>(error: E) -> Result<T, Infallible>
@@ -64,11 +65,33 @@ pub fn map_debug_log<E>(error: E) -> Error
 }
 
 #[inline]
-pub fn inspect_log<E: fmt::Display>(error: &E) {
-	error!("{error}");
+pub fn inspect_log<E: fmt::Display>(error: &E) { inspect_log_level(error, Level::ERROR); }
+
+#[inline]
+pub fn inspect_debug_log<E: fmt::Debug>(error: &E) { inspect_debug_log_level(error, Level::ERROR); }
+
+#[inline]
+pub fn inspect_log_level<E: fmt::Display>(error: &E, level: Level) {
+	use crate::{debug, error, info, trace, warn};
+
+	match level {
+		Level::ERROR => error!("{error}"),
+		Level::WARN => warn!("{error}"),
+		Level::INFO => info!("{error}"),
+		Level::DEBUG => debug!("{error}"),
+		Level::TRACE => trace!("{error}"),
+	}
 }
 
 #[inline]
-pub fn inspect_debug_log<E: fmt::Debug>(error: &E) {
-	debug_error!("{error:?}");
+pub fn inspect_debug_log_level<E: fmt::Debug>(error: &E, level: Level) {
+	use crate::{debug, debug_error, debug_info, debug_warn, trace};
+
+	match level {
+		Level::ERROR => debug_error!("{error:?}"),
+		Level::WARN => debug_warn!("{error:?}"),
+		Level::INFO => debug_info!("{error:?}"),
+		Level::DEBUG => debug!("{error:?}"),
+		Level::TRACE => trace!("{error:?}"),
+	}
 }
diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs
index 92dbdfe3bb9f35f062eebb58c2c1852283b243fc..35bf98009f8552d06d160eff9f5714848a942b5d 100644
--- a/src/core/error/mod.rs
+++ b/src/core/error/mod.rs
@@ -4,10 +4,9 @@ mod response;
 mod serde;
 
-use std::{any::Any, borrow::Cow, convert::Infallible, fmt};
+use std::{any::Any, borrow::Cow, convert::Infallible, fmt, sync::PoisonError};
 
-pub use self::log::*;
-use crate::error;
+pub use self::{err::visit, log::*};
 
 #[derive(thiserror::Error)]
 pub enum Error {
@@ -59,6 +58,8 @@ pub enum Error {
 	JsTryFromInt(#[from] ruma::JsTryFromIntError), // js_int re-export
 	#[error(transparent)]
 	Path(#[from] axum::extract::rejection::PathRejection),
+	#[error("Mutex poisoned: {0}")]
+	Poison(Cow<'static, str>),
 	#[error("Regex error: {0}")]
 	Regex(#[from] regex::Error),
 	#[error("Request error: {0}")]
@@ -75,6 +76,8 @@ pub enum Error {
 	TracingFilter(#[from] tracing_subscriber::filter::ParseError),
 	#[error("Tracing reload error: {0}")]
 	TracingReload(#[from] tracing_subscriber::reload::Error),
+	#[error(transparent)]
+	Yaml(#[from] serde_yaml::Error),
 
 	// ruma/conduwuit
 	#[error("Arithmetic operation failed: {0}")]
@@ -83,10 +86,12 @@ pub enum Error {
 	BadRequest(ruma::api::client::error::ErrorKind, &'static str), //TODO: remove
 	#[error("{0}")]
 	BadServerResponse(Cow<'static, str>),
+	#[error(transparent)]
+	CanonicalJson(#[from] ruma::CanonicalJsonError),
 	#[error("There was a problem with the '{0}' directive in your configuration: {1}")]
 	Config(&'static str, Cow<'static, str>),
 	#[error("{0}")]
-	Conflict(&'static str), // This is only needed for when a room alias already exists
+	Conflict(Cow<'static, str>), // This is only needed for when a room alias already exists
 	#[error(transparent)]
 	ContentDisposition(#[from] ruma::http_headers::ContentDispositionParseError),
 	#[error("{0}")]
@@ -107,6 +112,10 @@ pub enum Error {
 	Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode),
 	#[error(transparent)]
 	Ruma(#[from] ruma::api::client::error::Error),
+	#[error(transparent)]
+	Signatures(#[from] ruma::signatures::Error),
+	#[error(transparent)]
+	StateRes(#[from] ruma::state_res::Error),
 	#[error("uiaa")]
 	Uiaa(ruma::api::client::uiaa::UiaaInfo),
 
@@ -116,17 +125,19 @@ pub enum Error {
 }
 
 impl Error {
+	//#[deprecated]
 	pub fn bad_database(message: &'static str) -> Self { crate::err!(Database(error!("{message}"))) }
 
 	/// Sanitizes public-facing errors that can leak sensitive information.
-	pub fn sanitized_string(&self) -> String {
+	pub fn sanitized_message(&self) -> String {
 		match self {
 			Self::Database(..) => String::from("Database error occurred."),
 			Self::Io(..) => String::from("I/O error occurred."),
-			_ => self.to_string(),
+			_ => self.message(),
 		}
 	}
 
+	/// Generate the error message string.
 	pub fn message(&self) -> String {
 		match self {
 			Self::Federation(ref origin, ref error) => format!("Answer from {origin}: {error}"),
@@ -141,25 +152,43 @@ pub fn kind(&self) -> ruma::api::client::error::ErrorKind {
 		use ruma::api::client::error::ErrorKind::Unknown;
 
 		match self {
-			Self::Federation(_, error) => response::ruma_error_kind(error).clone(),
+			Self::Federation(_, error) | Self::Ruma(error) => response::ruma_error_kind(error).clone(),
 			Self::BadRequest(kind, ..) | Self::Request(kind, ..) => kind.clone(),
 			_ => Unknown,
 		}
 	}
 
+	/// Returns the HTTP error code or closest approximation based on error
+	/// variant.
 	pub fn status_code(&self) -> http::StatusCode {
+		use http::StatusCode;
+
 		match self {
-			Self::Federation(_, ref error) | Self::Ruma(ref error) => error.status_code,
-			Self::Request(ref kind, _, code) => response::status_code(kind, *code),
-			Self::BadRequest(ref kind, ..) => response::bad_request_code(kind),
-			Self::Conflict(_) => http::StatusCode::CONFLICT,
-			_ => http::StatusCode::INTERNAL_SERVER_ERROR,
+			Self::Federation(_, error) | Self::Ruma(error) => error.status_code,
+			Self::Request(kind, _, code) => response::status_code(kind, *code),
+			Self::BadRequest(kind, ..) => response::bad_request_code(kind),
+			Self::Reqwest(error) => error.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
+			Self::Conflict(_) => StatusCode::CONFLICT,
+			_ => StatusCode::INTERNAL_SERVER_ERROR,
 		}
 	}
+
+	/// Returns true for "not found" errors. This means anything that qualifies
+	/// as a "not found" from any variant's contained error type. This call is
+	/// often used as a special case to eliminate a contained Option with a
+	/// Result where Ok(None) is instead Err(e) if e.is_not_found().
+	#[inline]
+	pub fn is_not_found(&self) -> bool { self.status_code() == http::StatusCode::NOT_FOUND }
 }
 
 impl fmt::Debug for Error {
-	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") }
+	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.message()) }
+}
+
+impl<T> From<PoisonError<T>> for Error {
+	#[cold]
+	#[inline(never)]
+	fn from(e: PoisonError<T>) -> Self { Self::Poison(e.to_string().into()) }
 }
 
 #[allow(clippy::fallible_impl_from)]
@@ -176,3 +205,9 @@ fn from(_e: Infallible) -> Self {
 pub fn infallible(_e: &Infallible) {
 	panic!("infallible error should never exist");
 }
+
+/// Convenience functor for fundamental Error::sanitized_message(); see member.
+#[inline]
+#[must_use]
+#[allow(clippy::needless_pass_by_value)]
+pub fn sanitized_message(e: Error) -> String { e.sanitized_message() }
diff --git a/src/core/error/panic.rs b/src/core/error/panic.rs
index c070f78669ddbd484a5c8835e2bc345073dfe634..bec25132d5892699b0908f664a6b2353fa94316a 100644
--- a/src/core/error/panic.rs
+++ b/src/core/error/panic.rs
@@ -10,11 +10,14 @@ impl UnwindSafe for Error {}
 impl RefUnwindSafe for Error {}
 
 impl Error {
+	#[inline]
 	pub fn panic(self) -> ! { panic_any(self.into_panic()) }
 
 	#[must_use]
+	#[inline]
 	pub fn from_panic(e: Box<dyn Any + Send>) -> Self { Self::Panic(debug::panic_str(&e), e) }
 
+	#[inline]
 	pub fn into_panic(self) -> Box<dyn Any + Send + 'static> {
 		match self {
 			Self::Panic(_, e) | Self::PanicAny(e) => e,
@@ -24,6 +27,7 @@ pub fn into_panic(self) -> Box<dyn Any + Send + 'static> {
 	}
 
 	/// Get the panic message string.
+	#[inline]
 	pub fn panic_str(self) -> Option<&'static str> {
 		self.is_panic()
 			.then_some(debug::panic_str(&self.into_panic()))
diff --git a/src/core/error/response.rs b/src/core/error/response.rs
index 7568a1c015d8cdabf7ef4220ebf58ca4e263be72..21fbdcf22e01e93f30db31c66338092cda2849be 100644
--- a/src/core/error/response.rs
+++ b/src/core/error/response.rs
@@ -26,6 +26,7 @@ fn into_response(self) -> axum::response::Response {
 }
 
 impl From<Error> for UiaaResponse {
+	#[inline]
 	fn from(error: Error) -> Self {
 		if let Error::Uiaa(uiaainfo) = error {
 			return Self::AuthResponse(uiaainfo);
diff --git a/src/core/log/fmt_span.rs b/src/core/log/fmt_span.rs
new file mode 100644
index 0000000000000000000000000000000000000000..5a340d0fa252013428364cf43ddcbc8dc83d92df
--- /dev/null
+++ b/src/core/log/fmt_span.rs
@@ -0,0 +1,17 @@
+use tracing_subscriber::fmt::format::FmtSpan;
+
+use crate::Result;
+
+#[inline]
+pub fn from_str(str: &str) -> Result<FmtSpan, FmtSpan> {
+	match str.to_uppercase().as_str() {
+		"ENTER" => Ok(FmtSpan::ENTER),
+		"EXIT" => Ok(FmtSpan::EXIT),
+		"NEW" => Ok(FmtSpan::NEW),
+		"CLOSE" => Ok(FmtSpan::CLOSE),
+		"ACTIVE" => Ok(FmtSpan::ACTIVE),
+		"FULL" => Ok(FmtSpan::FULL),
+		"NONE" => Ok(FmtSpan::NONE),
+		_ => Err(FmtSpan::NONE),
+	}
+}
diff --git a/src/core/log/mod.rs b/src/core/log/mod.rs
index 04d250a6d9701e42652e87749afde8f3c1f99817..48b7f0f389640a91c63ddde3756f534fc68ace8c 100644
--- a/src/core/log/mod.rs
+++ b/src/core/log/mod.rs
@@ -1,6 +1,9 @@
+#![allow(clippy::disallowed_macros)]
+
 pub mod capture;
 pub mod color;
 pub mod fmt;
+pub mod fmt_span;
 mod reload;
 mod suppress;
 
@@ -27,6 +30,11 @@ pub struct Log {
 // necessary but discouraged. Remember debug_ log macros are also exported to
 // the crate namespace like these.
 
+#[macro_export]
+macro_rules! event {
+	( $level:expr, $($x:tt)+ ) => { ::tracing::event!( $level, $($x)+ ) }
+}
+
 #[macro_export]
 macro_rules! error {
 	( $($x:tt)+ ) => { ::tracing::error!( $($x)+ ) }
diff --git a/src/core/mod.rs b/src/core/mod.rs
index 9898243bf6368da45e1b43055cc53abd2e6e7826..4ab847307b1a9bd73b7cf540c3dbfdc213280c2a 100644
--- a/src/core/mod.rs
+++ b/src/core/mod.rs
@@ -10,18 +10,19 @@ pub mod server;
 pub mod utils;
 
+pub use ::http;
+pub use ::ruma;
 pub use ::toml;
+pub use ::tracing;
 pub use config::Config;
 pub use error::Error;
 pub use info::{rustc_flags_capture, version, version::version};
-pub use pdu::{PduBuilder, PduCount, PduEvent};
+pub use pdu::{Event, PduBuilder, PduCount, PduEvent, PduId, RawPduId};
 pub use server::Server;
-pub use utils::{ctor, dtor, implement};
+pub use utils::{ctor, dtor, implement, result, result::Result};
 
 pub use crate as conduit_core;
 
-pub type Result<T, E = Error> = std::result::Result<T, E>;
-
 rustc_flags_capture! {}
 
 #[cfg(not(conduit_mods))]
diff --git a/src/core/pdu/builder.rs b/src/core/pdu/builder.rs
index ba4c19e57229ca331c0b301eebd2358a5bc6bb9d..80ff07130fa24bb70889d728f075e5c85f1796cf 100644
--- a/src/core/pdu/builder.rs
+++ b/src/core/pdu/builder.rs
@@ -1,20 +1,67 @@
 use std::{collections::BTreeMap, sync::Arc};
 
-use ruma::{events::TimelineEventType, EventId, MilliSecondsSinceUnixEpoch};
+use ruma::{
+	events::{EventContent, MessageLikeEventType, StateEventType, TimelineEventType},
+	EventId, MilliSecondsSinceUnixEpoch,
+};
 use serde::Deserialize;
-use serde_json::value::RawValue as RawJsonValue;
+use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
 
 /// Build the start of a PDU in order to add it to the Database.
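+///
+/// A usage sketch (an editor's illustration, not part of this changeset;
+/// `RoomTopicEventContent` merely stands in for any state event content type):
+///
+/// ```ignore
+/// let pdu = Builder::state(String::new(), &RoomTopicEventContent::new("hi".to_owned()));
+/// ```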
 #[derive(Debug, Deserialize)]
-pub struct PduBuilder {
+pub struct Builder {
 	#[serde(rename = "type")]
 	pub event_type: TimelineEventType,
+
 	pub content: Box<RawJsonValue>,
-	pub unsigned: Option<BTreeMap<String, serde_json::Value>>,
+
+	pub unsigned: Option<Unsigned>,
+
 	pub state_key: Option<String>,
+
 	pub redacts: Option<Arc<EventId>>,
-	/// For timestamped messaging, should only be used for appservices
-	///
+
+	/// For timestamped messaging, should only be used for appservices.
 	/// Will be set to current time if None
 	pub timestamp: Option<MilliSecondsSinceUnixEpoch>,
 }
+
+type Unsigned = BTreeMap<String, serde_json::Value>;
+
+impl Builder {
+	pub fn state<T>(state_key: String, content: &T) -> Self
+	where
+		T: EventContent<EventType = StateEventType>,
+	{
+		Self {
+			event_type: content.event_type().into(),
+			content: to_raw_value(content).expect("Builder failed to serialize state event content to RawValue"),
+			state_key: Some(state_key),
+			..Self::default()
+		}
+	}
+
+	pub fn timeline<T>(content: &T) -> Self
+	where
+		T: EventContent<EventType = MessageLikeEventType>,
+	{
+		Self {
+			event_type: content.event_type().into(),
+			content: to_raw_value(content).expect("Builder failed to serialize timeline event content to RawValue"),
+			..Self::default()
+		}
+	}
+}
+
+impl Default for Builder {
+	fn default() -> Self {
+		Self {
+			event_type: "m.room.message".into(),
+			content: Box::<RawJsonValue>::default(),
+			unsigned: None,
+			state_key: None,
+			redacts: None,
+			timestamp: None,
+		}
+	}
+}
diff --git a/src/core/pdu/content.rs b/src/core/pdu/content.rs
new file mode 100644
index 0000000000000000000000000000000000000000..fa724cb2d0d3711ff768ab0fb2af90883dd0faa0
--- /dev/null
+++ b/src/core/pdu/content.rs
@@ -0,0 +1,20 @@
+use serde::Deserialize;
+use serde_json::value::Value as JsonValue;
+
+use crate::{err, implement, Result};
+
+#[must_use]
+#[implement(super::Pdu)]
+pub fn get_content_as_value(&self) -> JsonValue {
+	self.get_content()
+		.expect("pdu content must be a valid JSON value")
+}
+
+#[implement(super::Pdu)]
+pub fn get_content<T>(&self) -> Result<T>
+where
+	T: for<'de> Deserialize<'de>,
+{
+	serde_json::from_str(self.content.get())
+		.map_err(|e| err!(Database("Failed to deserialize pdu content into type: {e}")))
+}
diff --git a/src/core/pdu/count.rs b/src/core/pdu/count.rs
index 094988b694f9d8924fba4a14958119073acf741b..852223825ec5cb117924d95171e63081b6a137a9 100644
--- a/src/core/pdu/count.rs
+++ b/src/core/pdu/count.rs
@@ -1,51 +1,174 @@
-use std::cmp::Ordering;
+#![allow(clippy::cast_possible_wrap, clippy::cast_sign_loss, clippy::as_conversions)]
 
-use ruma::api::client::error::ErrorKind;
+use std::{cmp::Ordering, fmt, fmt::Display, str::FromStr};
 
-use crate::{Error, Result};
+use ruma::api::Direction;
+
+use crate::{err, Error, Result};
 
 #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)]
-pub enum PduCount {
-	Backfilled(u64),
+pub enum Count {
 	Normal(u64),
+	Backfilled(i64),
 }
 
-impl PduCount {
+impl Count {
+	#[inline]
 	#[must_use]
-	pub fn min() -> Self { Self::Backfilled(u64::MAX) }
+	pub fn from_unsigned(unsigned: u64) -> Self { Self::from_signed(unsigned as i64) }
 
+	#[inline]
 	#[must_use]
-	pub fn max() -> Self { Self::Normal(u64::MAX) }
+	pub fn from_signed(signed: i64) -> Self {
+		match signed {
+			i64::MIN..=0 => Self::Backfilled(signed),
+			_ => Self::Normal(signed as u64),
+		}
+	}
 
-	pub fn try_from_string(token: &str) -> Result<Self> {
-		if let Some(stripped_token) = token.strip_prefix('-') {
-			stripped_token.parse().map(PduCount::Backfilled)
-		} else {
-			token.parse().map(PduCount::Normal)
+	#[inline]
+	#[must_use]
+	pub fn into_unsigned(self) -> u64 {
+		self.debug_assert_valid();
+		match self {
+			Self::Normal(i) => i,
+			Self::Backfilled(i) => i as u64,
 		}
-		.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token."))
 	}
 
+	#[inline]
 	#[must_use]
-	pub fn stringify(&self) -> String {
+	pub fn into_signed(self) -> i64 {
+		self.debug_assert_valid();
 		match self {
-			Self::Backfilled(x) => format!("-{x}"),
-			Self::Normal(x) => x.to_string(),
+			Self::Normal(i) => i as i64,
+			Self::Backfilled(i) => i,
 		}
 	}
-}
 
-impl PartialOrd for PduCount {
-	fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
+	#[inline]
+	#[must_use]
+	pub fn into_normal(self) -> Self {
+		self.debug_assert_valid();
+		match self {
+			Self::Normal(i) => Self::Normal(i),
+			Self::Backfilled(_) => Self::Normal(0),
+		}
+	}
+
+	#[inline]
+	pub fn checked_inc(self, dir: Direction) -> Result<Self, Error> {
+		match dir {
+			Direction::Forward => self.checked_add(1),
+			Direction::Backward => self.checked_sub(1),
+		}
+	}
+
+	#[inline]
+	pub fn checked_add(self, add: u64) -> Result<Self, Error> {
+		Ok(match self {
+			Self::Normal(i) => Self::Normal(
+				i.checked_add(add)
+					.ok_or_else(|| err!(Arithmetic("Count::Normal overflow")))?,
+			),
+			Self::Backfilled(i) => Self::Backfilled(
+				i.checked_add(add as i64)
+					.ok_or_else(|| err!(Arithmetic("Count::Backfilled overflow")))?,
+			),
+		})
+	}
+
+	#[inline]
+	pub fn checked_sub(self, sub: u64) -> Result<Self, Error> {
+		Ok(match self {
+			Self::Normal(i) => Self::Normal(
+				i.checked_sub(sub)
+					.ok_or_else(|| err!(Arithmetic("Count::Normal underflow")))?,
+			),
+			Self::Backfilled(i) => Self::Backfilled(
+				i.checked_sub(sub as i64)
+					.ok_or_else(|| err!(Arithmetic("Count::Backfilled underflow")))?,
+			),
+		})
+	}
+
+	#[inline]
+	#[must_use]
+	pub fn saturating_inc(self, dir: Direction) -> Self {
+		match dir {
+			Direction::Forward => self.saturating_add(1),
+			Direction::Backward => self.saturating_sub(1),
+		}
+	}
+
+	#[inline]
+	#[must_use]
+	pub fn saturating_add(self, add: u64) -> Self {
+		match self {
+			Self::Normal(i) => Self::Normal(i.saturating_add(add)),
+			Self::Backfilled(i) => Self::Backfilled(i.saturating_add(add as i64)),
+		}
+	}
+
+	#[inline]
+	#[must_use]
+	pub fn saturating_sub(self, sub: u64) -> Self {
+		match self {
+			Self::Normal(i) => Self::Normal(i.saturating_sub(sub)),
+			Self::Backfilled(i) => Self::Backfilled(i.saturating_sub(sub as i64)),
+		}
+	}
+
+	#[inline]
+	#[must_use]
+	pub const fn min() -> Self { Self::Backfilled(i64::MIN) }
+
+	#[inline]
+	#[must_use]
+	pub const fn max() -> Self { Self::Normal(i64::MAX as u64) }
+
+	#[inline]
+	pub(crate) fn debug_assert_valid(&self) {
+		if let Self::Backfilled(i) = self {
+			debug_assert!(*i <= 0, "Backfilled sequence must be negative");
+		}
+	}
 }
 
-impl Ord for PduCount {
-	fn cmp(&self, other: &Self) -> Ordering {
-		match (self, other) {
-			(Self::Normal(s), Self::Normal(o)) => s.cmp(o),
-			(Self::Backfilled(s), Self::Backfilled(o)) => o.cmp(s),
-			(Self::Normal(_), Self::Backfilled(_)) => Ordering::Greater,
-			(Self::Backfilled(_), Self::Normal(_)) => Ordering::Less,
+impl Display for Count {
+	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+		self.debug_assert_valid();
+		match self {
+			Self::Normal(i) => write!(f, "{i}"),
+			Self::Backfilled(i) => write!(f, "{i}"),
 		}
 	}
 }
+
+impl From<i64> for Count {
+	#[inline]
+	fn from(signed: i64) -> Self { Self::from_signed(signed) }
+}
+
+impl From<u64> for Count {
+	#[inline]
+	fn from(unsigned: u64) -> Self { Self::from_unsigned(unsigned) }
+}
+
+impl FromStr for Count {
+	type Err = Error;
+
+	fn from_str(token: &str) -> Result<Self, Self::Err> { Ok(Self::from_signed(token.parse()?)) }
+}
+
+impl PartialOrd for Count {
+	fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
+}
+
+impl Ord for Count {
+	fn cmp(&self, other: &Self) -> Ordering { self.into_signed().cmp(&other.into_signed()) }
+}
+
+impl Default for Count {
+	fn default() -> Self { Self::Normal(0) }
+}
diff --git a/src/core/pdu/event.rs b/src/core/pdu/event.rs
new file mode 100644
index 0000000000000000000000000000000000000000..96a1e4ba3e1fed9ad9f37ccf8a35a3e3cf47f17d
--- /dev/null
+++ b/src/core/pdu/event.rs
@@ -0,0 +1,31 @@
+use std::sync::Arc;
+
+pub use ruma::state_res::Event;
+use ruma::{events::TimelineEventType, EventId, MilliSecondsSinceUnixEpoch, RoomId, UserId};
+use serde_json::value::RawValue as RawJsonValue;
+
+use super::Pdu;
+
+impl Event for Pdu {
+	type Id = Arc<EventId>;
+
+	fn event_id(&self) -> &Self::Id { &self.event_id }
+
+	fn room_id(&self) -> &RoomId { &self.room_id }
+
+	fn sender(&self) -> &UserId { &self.sender }
+
+	fn event_type(&self) -> &TimelineEventType { &self.kind }
+
+	fn content(&self) -> &RawJsonValue { &self.content }
+
+	fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch(self.origin_server_ts) }
+
+	fn state_key(&self) -> Option<&str> { self.state_key.as_deref() }
+
+	fn prev_events(&self) -> impl DoubleEndedIterator<Item = &Self::Id> + Send + '_ { self.prev_events.iter() }
+
+	fn auth_events(&self) -> impl DoubleEndedIterator<Item = &Self::Id> + Send + '_ { self.auth_events.iter() }
+
+	fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() }
+}
diff --git a/src/core/pdu/event_id.rs b/src/core/pdu/event_id.rs
new file mode 100644
index 0000000000000000000000000000000000000000..ae5b85f9a03be2921725dfe677ee3a8de354f46f
--- /dev/null
+++ b/src/core/pdu/event_id.rs
@@ -0,0 +1,27 @@
+use ruma::{CanonicalJsonObject, OwnedEventId, RoomVersionId};
+use serde_json::value::RawValue as RawJsonValue;
+
+use crate::{err, Result};
+
+/// Generates a correct eventId for the incoming pdu.
+///
+/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String,
+/// CanonicalJsonValue>`.
+pub fn gen_event_id_canonical_json(
+	pdu: &RawJsonValue, room_version_id: &RoomVersionId,
+) -> Result<(OwnedEventId, CanonicalJsonObject)> {
+	let value: CanonicalJsonObject = serde_json::from_str(pdu.get())
+		.map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?;
+
+	let event_id = gen_event_id(&value, room_version_id)?;
+
+	Ok((event_id, value))
+}
+
+/// Generates a correct eventId for the incoming pdu.
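+///
+/// A usage sketch (an editor's illustration; `object` is a hypothetical
+/// `CanonicalJsonObject` already parsed from an incoming PDU):
+///
+/// ```ignore
+/// let event_id = gen_event_id(&object, &RoomVersionId::V10)?;
+/// ```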
+pub fn gen_event_id(value: &CanonicalJsonObject, room_version_id: &RoomVersionId) -> Result<OwnedEventId> {
+	let reference_hash = ruma::signatures::reference_hash(value, room_version_id)?;
+	let event_id: OwnedEventId = format!("${reference_hash}").try_into()?;
+
+	Ok(event_id)
+}
diff --git a/src/core/pdu/filter.rs b/src/core/pdu/filter.rs
new file mode 100644
index 0000000000000000000000000000000000000000..c7c7316d1b395276c3d1ddb2ff2833d27cff464d
--- /dev/null
+++ b/src/core/pdu/filter.rs
@@ -0,0 +1,90 @@
+use ruma::api::client::filter::{RoomEventFilter, UrlFilter};
+use serde_json::Value;
+
+use crate::{implement, is_equal_to};
+
+#[implement(super::Pdu)]
+#[must_use]
+pub fn matches(&self, filter: &RoomEventFilter) -> bool {
+	if !self.matches_sender(filter) {
+		return false;
+	}
+
+	if !self.matches_room(filter) {
+		return false;
+	}
+
+	if !self.matches_type(filter) {
+		return false;
+	}
+
+	if !self.matches_url(filter) {
+		return false;
+	}
+
+	true
+}
+
+#[implement(super::Pdu)]
+fn matches_room(&self, filter: &RoomEventFilter) -> bool {
+	if filter.not_rooms.contains(&self.room_id) {
+		return false;
+	}
+
+	if let Some(rooms) = filter.rooms.as_ref() {
+		if !rooms.contains(&self.room_id) {
+			return false;
+		}
+	}
+
+	true
+}
+
+#[implement(super::Pdu)]
+fn matches_sender(&self, filter: &RoomEventFilter) -> bool {
+	if filter.not_senders.contains(&self.sender) {
+		return false;
+	}
+
+	if let Some(senders) = filter.senders.as_ref() {
+		if !senders.contains(&self.sender) {
+			return false;
+		}
+	}
+
+	true
+}
+
+#[implement(super::Pdu)]
+fn matches_type(&self, filter: &RoomEventFilter) -> bool {
+	let event_type = &self.kind.to_cow_str();
+	if filter.not_types.iter().any(is_equal_to!(event_type)) {
+		return false;
+	}
+
+	if let Some(types) = filter.types.as_ref() {
+		if !types.iter().any(is_equal_to!(event_type)) {
+			return false;
+		}
+	}
+
+	true
+}
+
+#[implement(super::Pdu)]
+fn matches_url(&self, filter: &RoomEventFilter) -> bool {
+	let Some(url_filter) = filter.url_filter.as_ref() else {
+		return true;
+	};
+
+	//TODO: might be better to use Ruma's Raw rather than serde here
+	let url = serde_json::from_str::<Value>(self.content.get())
+		.expect("parsing content JSON failed")
+		.get("url")
+		.is_some_and(Value::is_string);
+
+	match url_filter {
+		UrlFilter::EventsWithUrl => url,
+		UrlFilter::EventsWithoutUrl => !url,
+	}
+}
diff --git a/src/core/pdu/id.rs b/src/core/pdu/id.rs
new file mode 100644
index 0000000000000000000000000000000000000000..0b23a29f867525d607075ca56ff06ffa84a8d330
--- /dev/null
+++ b/src/core/pdu/id.rs
@@ -0,0 +1,22 @@
+use super::{Count, RawId};
+use crate::utils::u64_from_u8x8;
+
+pub type ShortRoomId = ShortId;
+pub type ShortEventId = ShortId;
+pub type ShortId = u64;
+
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub struct Id {
+	pub shortroomid: ShortRoomId,
+	pub shorteventid: Count,
+}
+
+impl From<RawId> for Id {
+	#[inline]
+	fn from(raw: RawId) -> Self {
+		Self {
+			shortroomid: u64_from_u8x8(raw.shortroomid()),
+			shorteventid: Count::from_unsigned(u64_from_u8x8(raw.shorteventid())),
+		}
+	}
+}
diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs
index 439c831a539165ebf4b0884a9cebb850c394f213..2aa60ed1e8ec97301f7ec2598ec7ce6d50c375e2 100644
--- a/src/core/pdu/mod.rs
+++ b/src/core/pdu/mod.rs
@@ -1,44 +1,39 @@
 mod builder;
+mod content;
 mod count;
+mod event;
+mod event_id;
+mod filter;
+mod id;
+mod raw_id;
+mod redact;
+mod relation;
+mod strip;
+mod tests;
+mod unsigned;
+
+use std::{cmp::Ordering, sync::Arc};
 
collections::BTreeMap, sync::Arc}; - -pub use builder::PduBuilder; -pub use count::PduCount; use ruma::{ - canonical_json::redact_content_in_place, - events::{ - room::{member::RoomMemberEventContent, redaction::RoomRedactionEventContent}, - space::child::HierarchySpaceChildEvent, - AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, - AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType, - }, - serde::Raw, - state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, - OwnedUserId, RoomId, RoomVersionId, UInt, UserId, + events::TimelineEventType, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedRoomId, OwnedUserId, UInt, }; use serde::{Deserialize, Serialize}; -use serde_json::{ - json, - value::{to_raw_value, RawValue as RawJsonValue}, +use serde_json::value::RawValue as RawJsonValue; + +pub use self::{ + builder::{Builder, Builder as PduBuilder}, + count::Count, + event::Event, + event_id::*, + id::*, + raw_id::*, + Count as PduCount, Id as PduId, Pdu as PduEvent, RawId as RawPduId, }; +use crate::Result; -use crate::{err, warn, Error}; - -#[derive(Deserialize)] -struct ExtractRedactedBecause { - redacted_because: Option<serde::de::IgnoredAny>, -} - -/// Content hashes of a PDU. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct EventHash { - /// The SHA-256 hash. - pub sha256: String, -} - +/// Persistent Data Unit (Event) #[derive(Clone, Deserialize, Serialize, Debug)] -pub struct PduEvent { +pub struct Pdu { pub event_id: Arc<EventId>, pub room_id: OwnedRoomId, pub sender: OwnedUserId, @@ -59,353 +54,41 @@ pub struct PduEvent { pub unsigned: Option<Box<RawJsonValue>>, pub hashes: EventHash, #[serde(default, skip_serializing_if = "Option::is_none")] - pub signatures: Option<Box<RawJsonValue>>, /* BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, - * String>> */ + // BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, String>> + pub signatures: Option<Box<RawJsonValue>>, } -impl PduEvent { - #[tracing::instrument(skip(self), level = "debug")] - pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> crate::Result<()> { - self.unsigned = None; - - let mut content = serde_json::from_str(self.content.get()) - .map_err(|_| Error::bad_database("PDU in db has invalid content."))?; - redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) - .map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?; - - self.unsigned = Some( - to_raw_value(&json!({ - "redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works") - })) - .expect("to string always works"), - ); - - self.content = to_raw_value(&content).expect("to string always works"); - - Ok(()) - } - - #[must_use] - pub fn is_redacted(&self) -> bool { - let Some(unsigned) = &self.unsigned else { - return false; - }; - - let Ok(unsigned) = ExtractRedactedBecause::deserialize(&**unsigned) else { - return false; - }; - - unsigned.redacted_because.is_some() - } - - pub fn remove_transaction_id(&mut self) -> crate::Result<()> { - if let Some(unsigned) = &self.unsigned { - let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = serde_json::from_str(unsigned.get()) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; - unsigned.remove("transaction_id"); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); - } - - Ok(()) - } - - pub fn add_age(&mut self) -> crate::Result<()> { - let 
mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self - .unsigned - .as_ref() - .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; - - // deliberately allowing for the possibility of negative age - let now: i128 = MilliSecondsSinceUnixEpoch::now().get().into(); - let then: i128 = self.origin_server_ts.into(); - let this_age = now.saturating_sub(then); - - unsigned.insert("age".to_owned(), to_raw_value(&this_age).unwrap()); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); - - Ok(()) - } - - /// Copies the `redacts` property of the event to the `content` dict and - /// vice-versa. - /// - /// This follows the specification's - /// [recommendation](https://spec.matrix.org/v1.10/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property): - /// - /// > For backwards-compatibility with older clients, servers should add a - /// > redacts - /// > property to the top level of m.room.redaction events in when serving - /// > such events - /// > over the Client-Server API. - /// - /// > For improved compatibility with newer clients, servers should add a - /// > redacts property - /// > to the content of m.room.redaction events in older room versions when - /// > serving - /// > such events over the Client-Server API. - #[must_use] - pub fn copy_redacts(&self) -> (Option<Arc<EventId>>, Box<RawJsonValue>) { - if self.kind == TimelineEventType::RoomRedaction { - if let Ok(mut content) = serde_json::from_str::<RoomRedactionEventContent>(self.content.get()) { - if let Some(redacts) = content.redacts { - return (Some(redacts.into()), self.content.clone()); - } else if let Some(redacts) = self.redacts.clone() { - content.redacts = Some(redacts.into()); - return ( - self.redacts.clone(), - to_raw_value(&content).expect("Must be valid, we only added redacts field"), - ); - } - } - } - - (self.redacts.clone(), self.content.clone()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> { - let (redacts, content) = self.copy_redacts(); - let mut json = json!({ - "content": content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let Some(redacts) = &redacts { - json["redacts"] = json!(redacts); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - /// This only works for events that are also AnyRoomEvents. 
- #[tracing::instrument(skip(self), level = "debug")] - pub fn to_any_event(&self) -> Raw<AnyEphemeralRoomEvent> { - let (redacts, content) = self.copy_redacts(); - let mut json = json!({ - "content": content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let Some(redacts) = &redacts { - json["redacts"] = json!(redacts); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_room_event(&self) -> Raw<AnyTimelineEvent> { - let (redacts, content) = self.copy_redacts(); - let mut json = json!({ - "content": content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let Some(redacts) = &redacts { - json["redacts"] = json!(redacts); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> { - let (redacts, content) = self.copy_redacts(); - let mut json = json!({ - "content": content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let Some(redacts) = &redacts { - json["redacts"] = json!(redacts); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_state_event(&self) -> Raw<AnyStateEvent> { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - "state_key": self.state_key, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "state_key": self.state_key, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_stripped_state_event(&self) -> Raw<AnyStrippedStateEvent> { - let json = json!({ - "content": self.content, - "type": self.kind, - "sender": self.sender, - "state_key": self.state_key, - }); - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_stripped_spacechild_state_event(&self) -> 
Raw<HierarchySpaceChildEvent> { - let json = json!({ - "content": self.content, - "type": self.kind, - "sender": self.sender, - "state_key": self.state_key, - "origin_server_ts": self.origin_server_ts, - }); - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "redacts": self.redacts, - "room_id": self.room_id, - "state_key": self.state_key, - }); - - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - - serde_json::from_value(json).expect("Raw::from_value always works") - } - - pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result<Self, serde_json::Error> { - json.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); +/// Content hashes of a PDU. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct EventHash { + /// The SHA-256 hash. + pub sha256: String, +} - serde_json::from_value(serde_json::to_value(json).expect("valid JSON")) +impl Pdu { + pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result<Self> { + let event_id = CanonicalJsonValue::String(event_id.into()); + json.insert("event_id".into(), event_id); + serde_json::to_value(json) + .and_then(serde_json::from_value) + .map_err(Into::into) } } -impl state_res::Event for PduEvent { - type Id = Arc<EventId>; - - fn event_id(&self) -> &Self::Id { &self.event_id } - - fn room_id(&self) -> &RoomId { &self.room_id } - - fn sender(&self) -> &UserId { &self.sender } - - fn event_type(&self) -> &TimelineEventType { &self.kind } - - fn content(&self) -> &RawJsonValue { &self.content } - - fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch(self.origin_server_ts) } +/// Prevent derived equality which wouldn't limit itself to event_id +impl Eq for Pdu {} - fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } - - fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.prev_events.iter()) } - - fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.auth_events.iter()) } - - fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() } -} - -// These impl's allow us to dedup state snapshots when resolving state -// for incoming events (federation/send/{txn}). -impl Eq for PduEvent {} -impl PartialEq for PduEvent { +/// Equality determined by the Pdu's ID, not the memory representations. +impl PartialEq for Pdu { fn eq(&self, other: &Self) -> bool { self.event_id == other.event_id } } -impl PartialOrd for PduEvent { + +/// Ordering determined by the Pdu's ID, not the memory representations. +impl PartialOrd for Pdu { fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) } } -impl Ord for PduEvent { - fn cmp(&self, other: &Self) -> Ordering { self.event_id.cmp(&other.event_id) } -} - -/// Generates a correct eventId for the incoming pdu. -/// -/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, -/// CanonicalJsonValue>`. 
-pub fn gen_event_id_canonical_json( - pdu: &RawJsonValue, room_version_id: &RoomVersionId, -) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()) - .map_err(|e| err!(BadServerResponse(warn!("Error parsing incoming event: {e:?}"))))?; - let event_id = format!( - "${}", - // Anything higher than version3 behaves the same - ruma::signatures::reference_hash(&value, room_version_id).expect("ruma can calculate reference hashes") - ) - .try_into() - .expect("ruma's reference hashes are valid event ids"); - - Ok((event_id, value)) +/// Ordering determined by the Pdu's ID, not the memory representations. +impl Ord for Pdu { + fn cmp(&self, other: &Self) -> Ordering { self.event_id.cmp(&other.event_id) } } diff --git a/src/core/pdu/raw_id.rs b/src/core/pdu/raw_id.rs new file mode 100644 index 0000000000000000000000000000000000000000..ef8502f685162163b608e40b66881261534cfbe4 --- /dev/null +++ b/src/core/pdu/raw_id.rs @@ -0,0 +1,113 @@ +use arrayvec::ArrayVec; + +use super::{Count, Id, ShortEventId, ShortId, ShortRoomId}; + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum RawId { + Normal(RawIdNormal), + Backfilled(RawIdBackfilled), +} + +type RawIdNormal = [u8; RawId::NORMAL_LEN]; +type RawIdBackfilled = [u8; RawId::BACKFILLED_LEN]; + +const INT_LEN: usize = size_of::<ShortId>(); + +impl RawId { + const BACKFILLED_LEN: usize = size_of::<ShortRoomId>() + INT_LEN + size_of::<ShortEventId>(); + const MAX_LEN: usize = Self::BACKFILLED_LEN; + const NORMAL_LEN: usize = size_of::<ShortRoomId>() + size_of::<ShortEventId>(); + + #[inline] + #[must_use] + pub fn pdu_count(&self) -> Count { + let id: Id = (*self).into(); + id.shorteventid + } + + #[inline] + #[must_use] + pub fn shortroomid(self) -> [u8; INT_LEN] { + match self { + Self::Normal(raw) => raw[0..INT_LEN] + .try_into() + .expect("normal raw shortroomid array from slice"), + Self::Backfilled(raw) => raw[0..INT_LEN] + .try_into() + .expect("backfilled raw shortroomid array from slice"), + } + } + + #[inline] + #[must_use] + pub fn shorteventid(self) -> [u8; INT_LEN] { + match self { + Self::Normal(raw) => raw[INT_LEN..INT_LEN * 2] + .try_into() + .expect("normal raw shorteventid array from slice"), + Self::Backfilled(raw) => raw[INT_LEN * 2..INT_LEN * 3] + .try_into() + .expect("backfilled raw shorteventid array from slice"), + } + } + + #[inline] + #[must_use] + pub fn as_bytes(&self) -> &[u8] { + match self { + Self::Normal(ref raw) => raw, + Self::Backfilled(ref raw) => raw, + } + } +} + +impl AsRef<[u8]> for RawId { + #[inline] + fn as_ref(&self) -> &[u8] { self.as_bytes() } +} + +impl From<&[u8]> for RawId { + #[inline] + fn from(id: &[u8]) -> Self { + match id.len() { + Self::NORMAL_LEN => Self::Normal( + id[0..Self::NORMAL_LEN] + .try_into() + .expect("normal RawId from [u8]"), + ), + Self::BACKFILLED_LEN => Self::Backfilled( + id[0..Self::BACKFILLED_LEN] + .try_into() + .expect("backfilled RawId from [u8]"), + ), + _ => unimplemented!("unrecognized RawId length"), + } + } +} + +impl From<Id> for RawId { + #[inline] + fn from(id: Id) -> Self { + const MAX_LEN: usize = RawId::MAX_LEN; + type RawVec = ArrayVec<u8, MAX_LEN>; + + let mut vec = RawVec::new(); + vec.extend(id.shortroomid.to_be_bytes()); + id.shorteventid.debug_assert_valid(); + match id.shorteventid { + Count::Normal(shorteventid) => { + vec.extend(shorteventid.to_be_bytes()); + Self::Normal(vec.as_ref().try_into().expect("RawVec into RawId::Normal")) + }, + Count::Backfilled(shorteventid) => { + 
vec.extend(0_u64.to_be_bytes()); + vec.extend(shorteventid.to_be_bytes()); + Self::Backfilled( + vec.as_ref() + .try_into() + .expect("RawVec into RawId::Backfilled"), + ) + }, + } + } +} diff --git a/src/core/pdu/redact.rs b/src/core/pdu/redact.rs new file mode 100644 index 0000000000000000000000000000000000000000..e116e563d636fc7d0c0aa689c9ecee041a1cba07 --- /dev/null +++ b/src/core/pdu/redact.rs @@ -0,0 +1,93 @@ +use std::sync::Arc; + +use ruma::{ + canonical_json::redact_content_in_place, + events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, + EventId, RoomVersionId, +}; +use serde::Deserialize; +use serde_json::{ + json, + value::{to_raw_value, RawValue as RawJsonValue}, +}; + +use crate::{implement, warn, Error, Result}; + +#[derive(Deserialize)] +struct ExtractRedactedBecause { + redacted_because: Option<serde::de::IgnoredAny>, +} + +#[implement(super::Pdu)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> Result { + self.unsigned = None; + + let mut content = + serde_json::from_str(self.content.get()).map_err(|_| Error::bad_database("PDU in db has invalid content."))?; + + redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) + .map_err(|e| Error::Redaction(self.sender.server_name().to_owned(), e))?; + + self.unsigned = Some( + to_raw_value(&json!({ + "redacted_because": serde_json::to_value(reason).expect("to_value(Pdu) always works") + })) + .expect("to string always works"), + ); + + self.content = to_raw_value(&content).expect("to string always works"); + + Ok(()) +} + +#[implement(super::Pdu)] +#[must_use] +pub fn is_redacted(&self) -> bool { + let Some(unsigned) = &self.unsigned else { + return false; + }; + + let Ok(unsigned) = ExtractRedactedBecause::deserialize(&**unsigned) else { + return false; + }; + + unsigned.redacted_because.is_some() +} + +/// Copies the `redacts` property of the event to the `content` dict and +/// vice-versa. +/// +/// This follows the specification's +/// [recommendation](https://spec.matrix.org/v1.10/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property): +/// +/// > For backwards-compatibility with older clients, servers should add a +/// > redacts +/// > property to the top level of m.room.redaction events in when serving +/// > such events +/// > over the Client-Server API. +/// +/// > For improved compatibility with newer clients, servers should add a +/// > redacts property +/// > to the content of m.room.redaction events in older room versions when +/// > serving +/// > such events over the Client-Server API. 
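+///
+/// A hedged usage sketch (the `pdu` binding is assumed, not part of this
+/// change):
+///
+/// ```ignore
+/// let (redacts, content) = pdu.copy_redacts();
+/// // whichever representation the event used, the redaction target is now
+/// // available both as `redacts` and inside the returned `content`
+/// ```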
+#[implement(super::Pdu)] +#[must_use] +pub fn copy_redacts(&self) -> (Option<Arc<EventId>>, Box<RawJsonValue>) { + if self.kind == TimelineEventType::RoomRedaction { + if let Ok(mut content) = serde_json::from_str::<RoomRedactionEventContent>(self.content.get()) { + if let Some(redacts) = content.redacts { + return (Some(redacts.into()), self.content.clone()); + } else if let Some(redacts) = self.redacts.clone() { + content.redacts = Some(redacts.into()); + return ( + self.redacts.clone(), + to_raw_value(&content).expect("Must be valid, we only added redacts field"), + ); + } + } + } + + (self.redacts.clone(), self.content.clone()) +} diff --git a/src/core/pdu/relation.rs b/src/core/pdu/relation.rs new file mode 100644 index 0000000000000000000000000000000000000000..2968171e3a6cd32042c2063b00882ffb4225b072 --- /dev/null +++ b/src/core/pdu/relation.rs @@ -0,0 +1,22 @@ +use ruma::events::relation::RelationType; +use serde::Deserialize; + +use crate::implement; + +#[derive(Clone, Debug, Deserialize)] +struct ExtractRelType { + rel_type: RelationType, +} +#[derive(Clone, Debug, Deserialize)] +struct ExtractRelatesToEventId { + #[serde(rename = "m.relates_to")] + relates_to: ExtractRelType, +} + +#[implement(super::Pdu)] +#[must_use] +pub fn relation_type_equal(&self, rel_type: &RelationType) -> bool { + self.get_content() + .map(|c: ExtractRelatesToEventId| c.relates_to.rel_type) + .is_ok_and(|r| r == *rel_type) +} diff --git a/src/core/pdu/strip.rs b/src/core/pdu/strip.rs new file mode 100644 index 0000000000000000000000000000000000000000..30fee863cc3e741cd8a96e272c9bcc0809aefd15 --- /dev/null +++ b/src/core/pdu/strip.rs @@ -0,0 +1,208 @@ +use ruma::{ + events::{ + room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent, AnyEphemeralRoomEvent, + AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent, + AnyTimelineEvent, StateEvent, + }, + serde::Raw, +}; +use serde_json::{json, value::Value as JsonValue}; + +use crate::{implement, warn}; + +#[implement(super::Pdu)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + json["redacts"] = json!(redacts); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +/// This only works for events that are also AnyRoomEvents. 
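+///
+/// Illustrative sketch (assumes `pdu` is such an event):
+///
+/// ```ignore
+/// let raw: Raw<AnyEphemeralRoomEvent> = pdu.to_any_event();
+/// ```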
+#[implement(super::Pdu)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_any_event(&self) -> Raw<AnyEphemeralRoomEvent> { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + json["redacts"] = json!(redacts); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::Pdu)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_room_event(&self) -> Raw<AnyTimelineEvent> { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + json["redacts"] = json!(redacts); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::Pdu)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> { + let (redacts, content) = self.copy_redacts(); + let mut json = json!({ + "content": content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &redacts { + json["redacts"] = json!(redacts); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::Pdu)] +#[must_use] +pub fn to_state_event_value(&self) -> JsonValue { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + "state_key": self.state_key, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + + json +} + +#[implement(super::Pdu)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_state_event(&self) -> Raw<AnyStateEvent> { + serde_json::from_value(self.to_state_event_value()).expect("Raw::from_value always works") +} + +#[implement(super::Pdu)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "state_key": self.state_key, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::Pdu)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_stripped_state_event(&self) -> Raw<AnyStrippedStateEvent> { + let json = json!({ + "content": self.content, + "type": self.kind, + "sender": 
self.sender, + "state_key": self.state_key, + }); + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::Pdu)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent> { + let json = json!({ + "content": self.content, + "type": self.kind, + "sender": self.sender, + "state_key": self.state_key, + "origin_server_ts": self.origin_server_ts, + }); + + serde_json::from_value(json).expect("Raw::from_value always works") +} + +#[implement(super::Pdu)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "redacts": self.redacts, + "room_id": self.room_id, + "state_key": self.state_key, + }); + + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + + serde_json::from_value(json).expect("Raw::from_value always works") +} diff --git a/src/core/pdu/tests.rs b/src/core/pdu/tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..ae3b1dd6da2806d6875ccb6de4cdf13570a1018b --- /dev/null +++ b/src/core/pdu/tests.rs @@ -0,0 +1,19 @@ +#![cfg(test)] + +use super::Count; + +#[test] +fn backfilled_parse() { + let count: Count = "-987654".parse().expect("parse() failed"); + let backfilled = matches!(count, Count::Backfilled(_)); + + assert!(backfilled, "not backfilled variant"); +} + +#[test] +fn normal_parse() { + let count: Count = "987654".parse().expect("parse() failed"); + let backfilled = matches!(count, Count::Backfilled(_)); + + assert!(!backfilled, "backfilled variant"); +} diff --git a/src/core/pdu/unsigned.rs b/src/core/pdu/unsigned.rs new file mode 100644 index 0000000000000000000000000000000000000000..6f3e4401644b4609c5434bc52ea7735bb6de5dbb --- /dev/null +++ b/src/core/pdu/unsigned.rs @@ -0,0 +1,110 @@ +use std::collections::BTreeMap; + +use ruma::MilliSecondsSinceUnixEpoch; +use serde::Deserialize; +use serde_json::value::{to_raw_value, RawValue as RawJsonValue, Value as JsonValue}; + +use super::Pdu; +use crate::{err, implement, is_true, Result}; + +#[implement(Pdu)] +pub fn remove_transaction_id(&mut self) -> Result { + let Some(unsigned) = &self.unsigned else { + return Ok(()); + }; + + let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = + serde_json::from_str(unsigned.get()).map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + unsigned.remove("transaction_id"); + self.unsigned = to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); + + Ok(()) +} + +#[implement(Pdu)] +pub fn add_age(&mut self) -> Result { + let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = self + .unsigned + .as_ref() + .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) + .map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?; + + // deliberately allowing for the possibility of negative age + let now: i128 = MilliSecondsSinceUnixEpoch::now().get().into(); + let then: i128 = self.origin_server_ts.into(); + let this_age = now.saturating_sub(then); + + unsigned.insert("age".to_owned(), to_raw_value(&this_age).expect("age is valid")); + self.unsigned = to_raw_value(&unsigned) + .map(Some) + .expect("unsigned is valid"); + + Ok(()) +} + +#[implement(Pdu)] +pub fn add_relation(&mut self, name: &str, pdu: &Pdu) -> Result { + let mut unsigned: 
BTreeMap<String, JsonValue> = self
+		.unsigned
+		.as_ref()
+		.map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get()))
+		.map_err(|e| err!(Database("Invalid unsigned in pdu event: {e}")))?;
+
+	let relations: &mut JsonValue = unsigned.entry("m.relations".into()).or_default();
+	if !relations.is_object() {
+		// coerce a non-object (e.g. the Null from or_default()) into an empty map
+		*relations = JsonValue::Object(serde_json::Map::new());
+	}
+
+	relations
+		.as_object_mut()
+		.expect("we just created it")
+		.insert(name.to_owned(), serde_json::to_value(pdu)?);
+
+	self.unsigned = to_raw_value(&unsigned)
+		.map(Some)
+		.expect("unsigned is valid");
+
+	Ok(())
+}
+
+#[implement(Pdu)]
+pub fn contains_unsigned_property<F>(&self, property: &str, is_type: F) -> bool
+where
+	F: FnOnce(&JsonValue) -> bool,
+{
+	self.get_unsigned_as_value()
+		.get(property)
+		.map(is_type)
+		.is_some_and(is_true!())
+}
+
+#[implement(Pdu)]
+pub fn get_unsigned_property<T>(&self, property: &str) -> Result<T>
+where
+	T: for<'de> Deserialize<'de>,
+{
+	self.get_unsigned_as_value()
+		.get_mut(property)
+		.map(JsonValue::take)
+		.map(serde_json::from_value)
+		.ok_or(err!(Request(NotFound("property not found in unsigned object"))))?
+		.map_err(|e| err!(Database("Failed to deserialize unsigned.{property} into type: {e}")))
+}
+
+#[implement(Pdu)]
+#[must_use]
+pub fn get_unsigned_as_value(&self) -> JsonValue { self.get_unsigned::<JsonValue>().unwrap_or_default() }
+
+#[implement(Pdu)]
+pub fn get_unsigned<T>(&self) -> Result<T>
+where
+	T: for<'de> Deserialize<'de>,
+{
+	self.unsigned
+		.as_ref()
+		.map(|raw| raw.get())
+		.map(serde_json::from_str)
+		.ok_or(err!(Request(NotFound("\"unsigned\" property not found in pdu"))))?
+		.map_err(|e| err!(Database("Failed to deserialize \"unsigned\" into value: {e}")))
+}
diff --git a/src/core/server.rs b/src/core/server.rs
index 89f1dea58e394413f4bad280799ca076f4ed3202..627e125d630b1497ce648acec4b99ff8adfe4f76 100644
--- a/src/core/server.rs
+++ b/src/core/server.rs
@@ -5,7 +5,7 @@
 use tokio::{runtime, sync::broadcast};
 
-use crate::{config::Config, log::Log, metrics::Metrics, Err, Result};
+use crate::{config::Config, err, log::Log, metrics::Metrics, Err, Result};
 
 /// Server runtime state; public portion
 pub struct Server {
@@ -107,6 +107,13 @@ pub fn runtime(&self) -> &runtime::Handle {
 			.expect("runtime handle available in Server")
 	}
 
+	#[inline]
+	pub fn check_running(&self) -> Result {
+		self.running()
+			.then_some(())
+			.ok_or_else(|| err!(debug_warn!("Server is shutting down.")))
+	}
+
 	#[inline]
 	pub fn running(&self) -> bool { !self.stopping.load(Ordering::Acquire) }
diff --git a/src/core/utils/arrayvec.rs b/src/core/utils/arrayvec.rs
new file mode 100644
index 0000000000000000000000000000000000000000..685aaf18c1af877199a0696d168488bcf1119bca
--- /dev/null
+++ b/src/core/utils/arrayvec.rs
@@ -0,0 +1,15 @@
+use ::arrayvec::ArrayVec;
+
+pub trait ArrayVecExt<T> {
+	fn extend_from_slice(&mut self, other: &[T]) -> &mut Self;
+}
+
+impl<T: Copy, const CAP: usize> ArrayVecExt<T> for ArrayVec<T, CAP> {
+	#[inline]
+	fn extend_from_slice(&mut self, other: &[T]) -> &mut Self {
+		self.try_extend_from_slice(other)
+			.expect("Insufficient buffer capacity to extend from slice");
+
+		self
+	}
+}
diff --git a/src/core/utils/bool.rs b/src/core/utils/bool.rs
new file mode 100644
index 0000000000000000000000000000000000000000..e9f399d49d98733966a2b878fe90b169eb03d987
--- /dev/null
+++ b/src/core/utils/bool.rs
@@ -0,0 +1,88 @@
+//! Trait BoolExt
+
+/// Boolean extensions and chain starters
+pub trait BoolExt {
+	#[must_use]
+	fn clone_or<T: Clone>(self, err: T, t: &T) -> T;
+
+	#[must_use]
+	fn copy_or<T: Copy>(self, err: T, t: T) -> T;
+
+	#[must_use]
+	fn expect(self, msg: &str) -> Self;
+
+	#[must_use]
+	fn expect_false(self, msg: &str) -> Self;
+
+	fn into_option(self) -> Option<()>;
+
+	#[allow(clippy::result_unit_err)]
+	fn into_result(self) -> Result<(), ()>;
+
+	fn map<T, F: FnOnce(Self) -> T>(self, f: F) -> T
+	where
+		Self: Sized;
+
+	fn map_ok_or<T, E, F: FnOnce() -> T>(self, err: E, f: F) -> Result<T, E>;
+
+	fn map_or<T, F: FnOnce() -> T>(self, err: T, f: F) -> T;
+
+	fn map_or_else<T, D: FnOnce() -> T, F: FnOnce() -> T>(self, err: D, f: F) -> T;
+
+	fn ok_or<E>(self, err: E) -> Result<(), E>;
+
+	fn ok_or_else<E, F: FnOnce() -> E>(self, err: F) -> Result<(), E>;
+
+	fn or<T, F: FnOnce() -> T>(self, f: F) -> Option<T>;
+
+	fn or_some<T>(self, t: T) -> Option<T>;
+}
+
+impl BoolExt for bool {
+	#[inline]
+	fn clone_or<T: Clone>(self, err: T, t: &T) -> T { self.map_or(err, || t.clone()) }
+
+	#[inline]
+	fn copy_or<T: Copy>(self, err: T, t: T) -> T { self.map_or(err, || t) }
+
+	#[inline]
+	fn expect(self, msg: &str) -> Self { self.then_some(true).expect(msg) }
+
+	#[inline]
+	fn expect_false(self, msg: &str) -> Self { (!self).then_some(false).expect(msg) }
+
+	#[inline]
+	fn into_option(self) -> Option<()> { self.then_some(()) }
+
+	#[inline]
+	fn into_result(self) -> Result<(), ()> { self.ok_or(()) }
+
+	#[inline]
+	fn map<T, F: FnOnce(Self) -> T>(self, f: F) -> T
+	where
+		Self: Sized,
+	{
+		f(self)
+	}
+
+	#[inline]
+	fn map_ok_or<T, E, F: FnOnce() -> T>(self, err: E, f: F) -> Result<T, E> { self.ok_or(err).map(|()| f()) }
+
+	#[inline]
+	fn map_or<T, F: FnOnce() -> T>(self, err: T, f: F) -> T { self.then(f).unwrap_or(err) }
+
+	#[inline]
+	fn map_or_else<T, D: FnOnce() -> T, F: FnOnce() -> T>(self, err: D, f: F) -> T { self.then(f).unwrap_or_else(err) }
+
+	#[inline]
+	fn ok_or<E>(self, err: E) -> Result<(), E> { self.into_option().ok_or(err) }
+
+	#[inline]
+	fn ok_or_else<E, F: FnOnce() -> E>(self, err: F) -> Result<(), E> { self.into_option().ok_or_else(err) }
+
+	#[inline]
+	fn or<T, F: FnOnce() -> T>(self, f: F) -> Option<T> { (!self).then(f) }
+
+	#[inline]
+	fn or_some<T>(self, t: T) -> Option<T> { (!self).then_some(t) }
+}
diff --git a/src/core/utils/bytes.rs b/src/core/utils/bytes.rs
index e8975a491205dafa1e3097e62b4af641ab5261b4..441ba422a3cf21e76d0d415282430b76d458b021 100644
--- a/src/core/utils/bytes.rs
+++ b/src/core/utils/bytes.rs
@@ -1,4 +1,32 @@
-use crate::Result;
+use bytesize::ByteSize;
+
+use crate::{err, Result};
+
+/// Parse a human-writable size string w/ SI-unit suffix into integer
+#[inline]
+pub fn from_str(str: &str) -> Result<usize> {
+	let bytes: ByteSize = str
+		.parse()
+		.map_err(|e| err!(Arithmetic("Failed to parse byte size: {e}")))?;
+
+	let bytes: usize = bytes
+		.as_u64()
+		.try_into()
+		.map_err(|e| err!(Arithmetic("Failed to convert u64 to usize: {e}")))?;
+
+	Ok(bytes)
+}
+
+/// Output a human-readable size string w/ SI-unit suffix
+#[inline]
+#[must_use]
+pub fn pretty(bytes: usize) -> String {
+	const SI_UNITS: bool = true;
+
+	let bytes: u64 = bytes.try_into().expect("failed to convert usize to u64");
+
+	bytesize::to_string(bytes, SI_UNITS)
+}
 
 #[inline]
 #[must_use]
diff --git a/src/core/utils/content_disposition.rs b/src/core/utils/content_disposition.rs
index a2fe923c405a94d7b88219cdf0ce3e2b18a23f86..3a264a74f9251393b20324d955c6b00fb5a249e0 100644
--- a/src/core/utils/content_disposition.rs
+++ b/src/core/utils/content_disposition.rs
@@ -45,9 +45,10 @@ pub fn content_disposition_type(content_type: Option<&str>) -> ContentDispositio
 		return ContentDispositionType::Attachment;
 	};
 
-	// is_sorted is unstable
-	/* debug_assert!(ALLOWED_INLINE_CONTENT_TYPES.is_sorted(),
-	 * "ALLOWED_INLINE_CONTENT_TYPES is not sorted"); */
+	debug_assert!(
+		ALLOWED_INLINE_CONTENT_TYPES.is_sorted(),
+		"ALLOWED_INLINE_CONTENT_TYPES is not sorted"
+	);
 
 	let content_type: Cow<'_, str> = content_type
 		.split(';')
diff --git a/src/core/utils/defer.rs b/src/core/utils/defer.rs
index 08477b6f5e850738a12077d475cbba58e2a5ba30..29199700bd134f3b66ac2baebd73f51d16fa80d5 100644
--- a/src/core/utils/defer.rs
+++ b/src/core/utils/defer.rs
@@ -15,8 +15,14 @@ impl<F: FnMut()> Drop for _Defer_<F> {
 	};
 
 	($body:expr) => {
-		$crate::defer! {{
-			$body
-		}}
+		$crate::defer! {{ $body }}
+	};
+}
+
+#[macro_export]
+macro_rules! scope_restore {
+	($val:ident, $ours:expr) => {
+		let theirs = $crate::utils::exchange($val, $ours);
+		$crate::defer! {{ *$val = theirs; }};
 	};
 }
diff --git a/src/core/utils/future/mod.rs b/src/core/utils/future/mod.rs
new file mode 100644
index 0000000000000000000000000000000000000000..6d45b656391d2b7b26d12fec5133ad4cd6430e34
--- /dev/null
+++ b/src/core/utils/future/mod.rs
@@ -0,0 +1,3 @@
+mod try_ext_ext;
+
+pub use try_ext_ext::TryExtExt;
diff --git a/src/core/utils/future/try_ext_ext.rs b/src/core/utils/future/try_ext_ext.rs
new file mode 100644
index 0000000000000000000000000000000000000000..f97ae885216fd43eaa88e070932a2cde5fe72894
--- /dev/null
+++ b/src/core/utils/future/try_ext_ext.rs
@@ -0,0 +1,83 @@
+//! Extended external extensions to futures::TryFutureExt
+#![allow(clippy::type_complexity)]
+
+use futures::{
+	future::{MapOkOrElse, UnwrapOrElse},
+	TryFuture, TryFutureExt,
+};
+
+/// This interface is not necessarily complete; feel free to add as-needed.
+pub trait TryExtExt<T, E>
+where
+	Self: TryFuture<Ok = T, Error = E> + Send,
+{
+	/// Resolves to a bool for whether the TryFuture (Future of a Result)
+	/// resolved to Ok or Err.
+	///
+	/// is_ok() has to consume *self rather than borrow. The intent of this
+	/// extension is therefore for callers only ever caring about the result
+	/// status while discarding all contents.
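+	///
+	/// A hedged usage sketch (the `some_try_future` binding and async context
+	/// are assumed, not part of this change):
+	///
+	/// ```ignore
+	/// if some_try_future.is_ok().await {
+	/// 	// success; the Ok value was discarded
+	/// }
+	/// ```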
+ #[allow(clippy::wrong_self_convention)] + fn is_ok(self) -> MapOkOrElse<Self, impl FnOnce(Self::Ok) -> bool, impl FnOnce(Self::Error) -> bool> + where + Self: Sized; + + fn map_ok_or<U, F>( + self, default: U, f: F, + ) -> MapOkOrElse<Self, impl FnOnce(Self::Ok) -> U, impl FnOnce(Self::Error) -> U> + where + F: FnOnce(Self::Ok) -> U, + Self: Send + Sized; + + fn ok( + self, + ) -> MapOkOrElse<Self, impl FnOnce(Self::Ok) -> Option<Self::Ok>, impl FnOnce(Self::Error) -> Option<Self::Ok>> + where + Self: Sized; + + fn unwrap_or(self, default: Self::Ok) -> UnwrapOrElse<Self, impl FnOnce(Self::Error) -> Self::Ok> + where + Self: Sized; +} + +impl<T, E, Fut> TryExtExt<T, E> for Fut +where + Fut: TryFuture<Ok = T, Error = E> + Send, +{ + #[inline] + fn is_ok(self) -> MapOkOrElse<Self, impl FnOnce(Self::Ok) -> bool, impl FnOnce(Self::Error) -> bool> + where + Self: Sized, + { + self.map_ok_or(false, |_| true) + } + + #[inline] + fn map_ok_or<U, F>( + self, default: U, f: F, + ) -> MapOkOrElse<Self, impl FnOnce(Self::Ok) -> U, impl FnOnce(Self::Error) -> U> + where + F: FnOnce(Self::Ok) -> U, + Self: Send + Sized, + { + self.map_ok_or_else(|_| default, f) + } + + #[inline] + fn ok( + self, + ) -> MapOkOrElse<Self, impl FnOnce(Self::Ok) -> Option<Self::Ok>, impl FnOnce(Self::Error) -> Option<Self::Ok>> + where + Self: Sized, + { + self.map_ok_or(None, Some) + } + + #[inline] + fn unwrap_or(self, default: Self::Ok) -> UnwrapOrElse<Self, impl FnOnce(Self::Error) -> Self::Ok> + where + Self: Sized, + { + self.unwrap_or_else(move |_| default) + } +} diff --git a/src/core/utils/hash.rs b/src/core/utils/hash.rs index 5a3664cb62021287249f38f08c68bff6cd7816c2..c12d4f663fc7c1f736feadbc1b26a3b4d1963c5f 100644 --- a/src/core/utils/hash.rs +++ b/src/core/utils/hash.rs @@ -1,13 +1,10 @@ mod argon; -mod sha256; +pub mod sha256; use crate::Result; -pub fn password(password: &str) -> Result<String> { argon::password(password) } - -pub fn verify_password(password: &str, password_hash: &str) -> Result<()> { +pub fn verify_password(password: &str, password_hash: &str) -> Result { argon::verify_password(password, password_hash) } -#[must_use] -pub fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> { sha256::hash(keys) } +pub fn password(password: &str) -> Result<String> { argon::password(password) } diff --git a/src/core/utils/hash/sha256.rs b/src/core/utils/hash/sha256.rs index b2e5a94c28222c259597daaef63ac565b0c5bee7..06e210a7e18aec70a3a0be06d65b9633889d5363 100644 --- a/src/core/utils/hash/sha256.rs +++ b/src/core/utils/hash/sha256.rs @@ -1,9 +1,62 @@ -use ring::{digest, digest::SHA256}; - -#[tracing::instrument(skip_all, level = "debug")] -pub(super) fn hash(keys: &[&[u8]]) -> Vec<u8> { - // We only hash the pdu's event ids, not the whole pdu - let bytes = keys.join(&0xFF); - let hash = digest::digest(&SHA256, &bytes); - hash.as_ref().to_owned() +use ring::{ + digest, + digest::{Context, SHA256, SHA256_OUTPUT_LEN}, +}; + +pub type Digest = [u8; SHA256_OUTPUT_LEN]; + +/// Sha256 hash (input gather joined by 0xFF bytes) +#[must_use] +#[tracing::instrument(skip(inputs), level = "trace")] +pub fn delimited<'a, T, I>(mut inputs: I) -> Digest +where + I: Iterator<Item = T> + 'a, + T: AsRef<[u8]> + 'a, +{ + let mut ctx = Context::new(&SHA256); + if let Some(input) = inputs.next() { + ctx.update(input.as_ref()); + for input in inputs { + ctx.update(b"\xFF"); + ctx.update(input.as_ref()); + } + } + + ctx.finish() + .as_ref() + .try_into() + .expect("failed to return Digest buffer") +} + +/// Sha256 hash (input gather) 
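+///
+/// Unlike [`delimited`], inputs are hashed back-to-back without separators,
+/// so input boundaries do not affect the digest. Illustrative sketch:
+///
+/// ```ignore
+/// assert_eq!(concat(["ab"].iter()), concat(["a", "b"].iter()));
+/// assert_ne!(delimited(["ab"].iter()), delimited(["a", "b"].iter()));
+/// ```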
+#[must_use] +#[tracing::instrument(skip(inputs), level = "trace")] +pub fn concat<'a, T, I>(inputs: I) -> Digest +where + I: Iterator<Item = T> + 'a, + T: AsRef<[u8]> + 'a, +{ + inputs + .fold(Context::new(&SHA256), |mut ctx, input| { + ctx.update(input.as_ref()); + ctx + }) + .finish() + .as_ref() + .try_into() + .expect("failed to return Digest buffer") +} + +/// Sha256 hash +#[inline] +#[must_use] +#[tracing::instrument(skip(input), level = "trace")] +pub fn hash<T>(input: T) -> Digest +where + T: AsRef<[u8]>, +{ + digest::digest(&SHA256, input.as_ref()) + .as_ref() + .try_into() + .expect("failed to return Digest buffer") } diff --git a/src/core/utils/math.rs b/src/core/utils/math.rs index f9d0de3022fcc600d8ed7296f2697aa4c4a0d975..ccff6400d9f9d415071e74386b96e91811a9ddc1 100644 --- a/src/core/utils/math.rs +++ b/src/core/utils/math.rs @@ -7,32 +7,50 @@ /// Checked arithmetic expression. Returns a Result<R, Error::Arithmetic> #[macro_export] macro_rules! checked { - ($($input:tt)*) => { - $crate::utils::math::checked_ops!($($input)*) + ($($input:tt)+) => { + $crate::utils::math::checked_ops!($($input)+) .ok_or_else(|| $crate::err!(Arithmetic("operation overflowed or result invalid"))) - } + }; +} + +/// Checked arithmetic expression which panics on failure. This is for +/// expressions which do not meet the threshold for validated! but the caller +/// has no realistic expectation for error and no interest in cluttering the +/// callsite with result handling from checked!. +#[macro_export] +macro_rules! expected { + ($msg:literal, $($input:tt)+) => { + $crate::checked!($($input)+).expect($msg) + }; + + ($($input:tt)+) => { + $crate::expected!("arithmetic expression expectation failure", $($input)+) + }; } -/// in release-mode. Use for performance when the expression is obviously safe. -/// The check remains in debug-mode for regression analysis. +/// Unchecked arithmetic expression in release-mode. Use for performance when +/// the expression is obviously safe. The check remains in debug-mode for +/// regression analysis. #[cfg(not(debug_assertions))] #[macro_export] macro_rules! validated { - ($($input:tt)*) => { + ($($input:tt)+) => { //#[allow(clippy::arithmetic_side_effects)] { //Some($($input)*) // .ok_or_else(|| $crate::err!(Arithmetic("this error should never been seen"))) //} //NOTE: remove me when stmt_expr_attributes is stable - $crate::checked!($($input)*) - } + $crate::expected!("validated arithmetic expression failed", $($input)+) + }; } +/// Checked arithmetic expression in debug-mode. Use for performance when +/// the expression is obviously safe. The check is elided in release-mode. #[cfg(debug_assertions)] #[macro_export] macro_rules! 
validated {
-	($($input:tt)*) => { $crate::checked!($($input)*) }
+	($($input:tt)+) => { $crate::expected!($($input)+) }
 }
 
 /// Returns false if the exponential backoff has expired based on the inputs
@@ -100,3 +118,6 @@ fn try_into_err<Dst: TryFrom<Src>, Src>(e: <Dst as TryFrom<Src>>::Error) -> Erro
 		type_name::<Dst>()
 	))
 }
+
+#[inline]
+pub fn clamp<T: Ord>(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) }
diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs
index 1556646ecc5e1ae80db70e3801951cbf4fbf7e1e..18c2dd6f315579a05cf9e0b673fee56381aa99ec 100644
--- a/src/core/utils/mod.rs
+++ b/src/core/utils/mod.rs
@@ -1,73 +1,156 @@
+pub mod arrayvec;
+pub mod bool;
 pub mod bytes;
 pub mod content_disposition;
 pub mod debug;
 pub mod defer;
+pub mod future;
 pub mod hash;
 pub mod html;
 pub mod json;
 pub mod math;
 pub mod mutex_map;
 pub mod rand;
+pub mod result;
+pub mod set;
+pub mod stream;
 pub mod string;
 pub mod sys;
 mod tests;
 pub mod time;
 
-use std::cmp::{self, Ordering};
-
+pub use ::conduit_macros::implement;
 pub use ::ctor::{ctor, dtor};
-pub use bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8};
-pub use conduit_macros::implement;
-pub use debug::slice_truncated as debug_slice_truncated;
-pub use hash::calculate_hash;
-pub use html::Escape as HtmlEscape;
-pub use json::{deserialize_from_str, to_canonical_object};
-pub use mutex_map::{Guard as MutexMapGuard, MutexMap};
-pub use rand::string as random_string;
-pub use string::{str_from_bytes, string_from_bytes};
-pub use sys::available_parallelism;
-pub use time::now_millis as millis_since_unix_epoch;
 
-#[inline]
-pub fn clamp<T: Ord>(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) }
+pub use self::{
+	arrayvec::ArrayVecExt,
+	bool::BoolExt,
+	bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8},
+	debug::slice_truncated as debug_slice_truncated,
+	future::TryExtExt as TryFutureExtExt,
+	hash::sha256::delimited as calculate_hash,
+	html::Escape as HtmlEscape,
+	json::{deserialize_from_str, to_canonical_object},
+	math::clamp,
+	mutex_map::{Guard as MutexMapGuard, MutexMap},
+	rand::{shuffle, string as random_string},
+	stream::{IterStream, ReadyExt, Tools as StreamTools, TryReadyExt},
+	string::{str_from_bytes, string_from_bytes},
+	sys::available_parallelism,
+	time::{now_millis as millis_since_unix_epoch, timepoint_ago, timepoint_from_now},
+};
 
 #[inline]
-pub fn exchange<T: Clone>(state: &mut T, source: T) -> T {
-	let ret = state.clone();
-	*state = source;
-	ret
+pub fn exchange<T>(state: &mut T, source: T) -> T { std::mem::replace(state, source) }
+
+#[macro_export]
+macro_rules! extract_variant {
+	($e:expr, $variant:path) => {
+		match $e {
+			$variant(value) => Some(value),
+			_ => None,
+		}
+	};
+}
+
+#[macro_export]
+macro_rules! apply {
+	(1, $($idx:tt)+) => {
+		|t| (($($idx)+)(t.0),)
+	};
+
+	(2, $($idx:tt)+) => {
+		|t| (($($idx)+)(t.0), ($($idx)+)(t.1),)
+	};
+
+	(3, $($idx:tt)+) => {
+		|t| (($($idx)+)(t.0), ($($idx)+)(t.1), ($($idx)+)(t.2),)
+	};
+
+	(4, $($idx:tt)+) => {
+		|t| (($($idx)+)(t.0), ($($idx)+)(t.1), ($($idx)+)(t.2), ($($idx)+)(t.3))
+	};
+}
+
+#[macro_export]
+macro_rules! at {
+	($idx:tt) => {
+		|t| t.$idx
+	};
+}
+
+/// Functor for equality i.e. .is_some_and(is_equal_to!(2))
+#[macro_export]
+macro_rules! is_equal_to {
+	($val:ident) => {
+		|x| x == $val
+	};
+
+	($val:expr) => {
+		|x| x == $val
+	};
+}
+
+/// Functor for less i.e. .is_some_and(is_less_than!(2))
+#[macro_export]
+macro_rules! is_less_than {
+	($val:ident) => {
+		|x| x < $val
+	};
+
+	($val:expr) => {
+		|x| x < $val
+	};
+}
+
+/// Functor for equality to zero
+#[macro_export]
+macro_rules! is_zero {
+	() => {
+		$crate::is_matching!(0)
+	};
+}
+
+/// Functor for matches! i.e. .is_some_and(is_matching!('A'..='Z'))
+#[macro_export]
+macro_rules! is_matching {
+	($val:ident) => {
+		|x| matches!(x, $val)
+	};
+
+	($val:expr) => {
+		|x| matches!(x, $val)
+	};
}
+
+/// Functor for !is_empty()
+#[macro_export]
+macro_rules! is_not_empty {
+	() => {
+		|x| !x.is_empty()
+	};
+}
+
+/// Functor for equality i.e. (a, b).map(is_equal!())
+#[macro_export]
+macro_rules! is_equal {
+	() => {
+		|a, b| a == b
+	};
 }
 
-#[must_use]
-pub fn generate_keypair() -> Vec<u8> {
-	let mut value = rand::string(8).as_bytes().to_vec();
-	value.push(0xFF);
-	value.extend_from_slice(
-		&ruma::signatures::Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)"),
-	);
-	value
+/// Functor for truthy
+#[macro_export]
+macro_rules! is_true {
+	() => {
+		|x| !!x
+	};
 }
 
-#[allow(clippy::impl_trait_in_params)]
-pub fn common_elements(
-	mut iterators: impl Iterator<Item = impl Iterator<Item = Vec<u8>>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering,
-) -> Option<impl Iterator<Item = Vec<u8>>> {
-	let first_iterator = iterators.next()?;
-	let mut other_iterators = iterators.map(Iterator::peekable).collect::<Vec<_>>();
-
-	Some(first_iterator.filter(move |target| {
-		other_iterators.iter_mut().all(|it| {
-			while let Some(element) = it.peek() {
-				match check_order(element, target) {
-					Ordering::Greater => return false, // We went too far
-					Ordering::Equal => return true, // Element is in both iters
-					Ordering::Less => {
-						// Keep searching
-						it.next();
-					},
-				}
-			}
-			false
-		})
-	}))
+/// Functor for falsy
+#[macro_export]
+macro_rules!
is_false { + () => { + |x| !x + }; } diff --git a/src/core/utils/rand.rs b/src/core/utils/rand.rs index b80671eb90c00f35019a247e3610ab32aad3c020..9e6fc7a816fb94d462eaf617d243e13c1df27728 100644 --- a/src/core/utils/rand.rs +++ b/src/core/utils/rand.rs @@ -3,7 +3,13 @@ time::{Duration, SystemTime}, }; -use rand::{thread_rng, Rng}; +use arrayvec::ArrayString; +use rand::{seq::SliceRandom, thread_rng, Rng}; + +pub fn shuffle<T>(vec: &mut [T]) { + let mut rng = thread_rng(); + vec.shuffle(&mut rng); +} pub fn string(length: usize) -> String { thread_rng() @@ -13,6 +19,18 @@ pub fn string(length: usize) -> String { .collect() } +#[inline] +pub fn string_array<const LENGTH: usize>() -> ArrayString<LENGTH> { + let mut ret = ArrayString::<LENGTH>::new(); + thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(LENGTH) + .map(char::from) + .for_each(|c| ret.push(c)); + + ret +} + #[inline] #[must_use] pub fn timepoint_secs(range: Range<u64>) -> SystemTime { diff --git a/src/core/utils/result.rs b/src/core/utils/result.rs new file mode 100644 index 0000000000000000000000000000000000000000..6b11ea66f0b9e34df47e78a84a7e1f2c762998f2 --- /dev/null +++ b/src/core/utils/result.rs @@ -0,0 +1,18 @@ +mod debug_inspect; +mod filter; +mod flat_ok; +mod into_is_ok; +mod log_debug_err; +mod log_err; +mod map_expect; +mod not_found; +mod unwrap_infallible; +mod unwrap_or_err; + +pub use self::{ + debug_inspect::DebugInspect, filter::Filter, flat_ok::FlatOk, into_is_ok::IntoIsOk, log_debug_err::LogDebugErr, + log_err::LogErr, map_expect::MapExpect, not_found::NotFound, unwrap_infallible::UnwrapInfallible, + unwrap_or_err::UnwrapOrErr, +}; + +pub type Result<T = (), E = crate::Error> = std::result::Result<T, E>; diff --git a/src/core/utils/result/debug_inspect.rs b/src/core/utils/result/debug_inspect.rs new file mode 100644 index 0000000000000000000000000000000000000000..ef80979d8528baa10617b667fb51d91d22b46441 --- /dev/null +++ b/src/core/utils/result/debug_inspect.rs @@ -0,0 +1,52 @@ +use super::Result; + +/// Inspect Result values with release-mode elision. +pub trait DebugInspect<T, E> { + /// Inspects an Err contained value in debug-mode. In release-mode closure F + /// is elided. + #[must_use] + fn debug_inspect_err<F: FnOnce(&E)>(self, f: F) -> Self; + + /// Inspects an Ok contained value in debug-mode. In release-mode closure F + /// is elided. 
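+	///
+	/// Sketch (the `do_work` call is hypothetical):
+	///
+	/// ```ignore
+	/// let res = do_work().debug_inspect(|v| tracing::debug!("got: {v:?}"));
+	/// ```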
+ #[must_use] + fn debug_inspect<F: FnOnce(&T)>(self, f: F) -> Self; +} + +#[cfg(debug_assertions)] +impl<T, E> DebugInspect<T, E> for Result<T, E> { + #[inline] + fn debug_inspect<F>(self, f: F) -> Self + where + F: FnOnce(&T), + { + self.inspect(f) + } + + #[inline] + fn debug_inspect_err<F>(self, f: F) -> Self + where + F: FnOnce(&E), + { + self.inspect_err(f) + } +} + +#[cfg(not(debug_assertions))] +impl<T, E> DebugInspect<T, E> for Result<T, E> { + #[inline] + fn debug_inspect<F>(self, _: F) -> Self + where + F: FnOnce(&T), + { + self + } + + #[inline] + fn debug_inspect_err<F>(self, _: F) -> Self + where + F: FnOnce(&E), + { + self + } +} diff --git a/src/core/utils/result/filter.rs b/src/core/utils/result/filter.rs new file mode 100644 index 0000000000000000000000000000000000000000..f11d363292d44e45e546476517a324c996d3aa7f --- /dev/null +++ b/src/core/utils/result/filter.rs @@ -0,0 +1,21 @@ +use super::Result; + +pub trait Filter<T, E> { + /// Similar to Option::filter + #[must_use] + fn filter<P, U>(self, predicate: P) -> Self + where + P: FnOnce(&T) -> Result<(), U>, + E: From<U>; +} + +impl<T, E> Filter<T, E> for Result<T, E> { + #[inline] + fn filter<P, U>(self, predicate: P) -> Self + where + P: FnOnce(&T) -> Result<(), U>, + E: From<U>, + { + self.and_then(move |t| predicate(&t).map(move |()| t).map_err(Into::into)) + } +} diff --git a/src/core/utils/result/flat_ok.rs b/src/core/utils/result/flat_ok.rs new file mode 100644 index 0000000000000000000000000000000000000000..e378e5d05e32ad4eecab27969fbffc7532b31e39 --- /dev/null +++ b/src/core/utils/result/flat_ok.rs @@ -0,0 +1,34 @@ +use super::Result; + +pub trait FlatOk<T> { + /// Equivalent to .transpose().ok().flatten() + fn flat_ok(self) -> Option<T>; + + /// Equivalent to .transpose().ok().flatten().ok_or(...) + fn flat_ok_or<E>(self, err: E) -> Result<T, E>; + + /// Equivalent to .transpose().ok().flatten().ok_or_else(...) 
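+	///
+	/// Illustrative sketch:
+	///
+	/// ```ignore
+	/// let x: Option<Result<u8, ()>> = None;
+	/// assert_eq!(x.flat_ok_or_else(|| "missing"), Err("missing"));
+	/// ```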
+ fn flat_ok_or_else<E, F: FnOnce() -> E>(self, err: F) -> Result<T, E>; +} + +impl<T, E> FlatOk<T> for Option<Result<T, E>> { + #[inline] + fn flat_ok(self) -> Option<T> { self.transpose().ok().flatten() } + + #[inline] + fn flat_ok_or<Ep>(self, err: Ep) -> Result<T, Ep> { self.flat_ok().ok_or(err) } + + #[inline] + fn flat_ok_or_else<Ep, F: FnOnce() -> Ep>(self, err: F) -> Result<T, Ep> { self.flat_ok().ok_or_else(err) } +} + +impl<T, E> FlatOk<T> for Result<Option<T>, E> { + #[inline] + fn flat_ok(self) -> Option<T> { self.ok().flatten() } + + #[inline] + fn flat_ok_or<Ep>(self, err: Ep) -> Result<T, Ep> { self.flat_ok().ok_or(err) } + + #[inline] + fn flat_ok_or_else<Ep, F: FnOnce() -> Ep>(self, err: F) -> Result<T, Ep> { self.flat_ok().ok_or_else(err) } +} diff --git a/src/core/utils/result/inspect_log.rs b/src/core/utils/result/inspect_log.rs new file mode 100644 index 0000000000000000000000000000000000000000..e9f32663c2b82647012b4d176d9573dcd535613e --- /dev/null +++ b/src/core/utils/result/inspect_log.rs @@ -0,0 +1,62 @@ +use std::fmt; + +use tracing::Level; + +use super::Result; +use crate::error; + +pub trait ErrLog<T, E> +where + E: fmt::Display, +{ + fn log_err(self, level: Level) -> Self; + + #[inline] + fn err_log(self) -> Self + where + Self: Sized, + { + self.log_err(Level::ERROR) + } +} + +pub trait ErrDebugLog<T, E> +where + E: fmt::Debug, +{ + fn log_err_debug(self, level: Level) -> Self; + + #[inline] + fn err_debug_log(self) -> Self + where + Self: Sized, + { + self.log_err_debug(Level::ERROR) + } +} + +impl<T, E> ErrLog<T, E> for Result<T, E> +where + E: fmt::Display, +{ + #[inline] + fn log_err(self, level: Level) -> Self + where + Self: Sized, + { + self.inspect_err(|error| error::inspect_log_level(&error, level)) + } +} + +impl<T, E> ErrDebugLog<T, E> for Result<T, E> +where + E: fmt::Debug, +{ + #[inline] + fn log_err_debug(self, level: Level) -> Self + where + Self: Sized, + { + self.inspect_err(|error| error::inspect_debug_log_level(&error, level)) + } +} diff --git a/src/core/utils/result/into_is_ok.rs b/src/core/utils/result/into_is_ok.rs new file mode 100644 index 0000000000000000000000000000000000000000..220ce010c5e2b2098313ab0f041985a3ac5bc262 --- /dev/null +++ b/src/core/utils/result/into_is_ok.rs @@ -0,0 +1,10 @@ +use super::Result; + +pub trait IntoIsOk<T, E> { + fn into_is_ok(self) -> bool; +} + +impl<T, E> IntoIsOk<T, E> for Result<T, E> { + #[inline] + fn into_is_ok(self) -> bool { self.is_ok() } +} diff --git a/src/core/utils/result/log_debug_err.rs b/src/core/utils/result/log_debug_err.rs new file mode 100644 index 0000000000000000000000000000000000000000..8835afd1943fb0148f895a43f694600daa3024dc --- /dev/null +++ b/src/core/utils/result/log_debug_err.rs @@ -0,0 +1,26 @@ +use std::fmt::Debug; + +use tracing::Level; + +use super::{DebugInspect, Result}; +use crate::error; + +pub trait LogDebugErr<T, E: Debug> { + #[must_use] + fn err_debug_log(self, level: Level) -> Self; + + #[must_use] + fn log_debug_err(self) -> Self + where + Self: Sized, + { + self.err_debug_log(Level::ERROR) + } +} + +impl<T, E: Debug> LogDebugErr<T, E> for Result<T, E> { + #[inline] + fn err_debug_log(self, level: Level) -> Self { + self.debug_inspect_err(|error| error::inspect_debug_log_level(&error, level)) + } +} diff --git a/src/core/utils/result/log_err.rs b/src/core/utils/result/log_err.rs new file mode 100644 index 0000000000000000000000000000000000000000..374a5e596e21a77464e170e479bb43bf73492cc3 --- /dev/null +++ b/src/core/utils/result/log_err.rs @@ -0,0 +1,24 @@ +use 
std::fmt::Display; + +use tracing::Level; + +use super::Result; +use crate::error; + +pub trait LogErr<T, E: Display> { + #[must_use] + fn err_log(self, level: Level) -> Self; + + #[must_use] + fn log_err(self) -> Self + where + Self: Sized, + { + self.err_log(Level::ERROR) + } +} + +impl<T, E: Display> LogErr<T, E> for Result<T, E> { + #[inline] + fn err_log(self, level: Level) -> Self { self.inspect_err(|error| error::inspect_log_level(&error, level)) } +} diff --git a/src/core/utils/result/map_expect.rs b/src/core/utils/result/map_expect.rs new file mode 100644 index 0000000000000000000000000000000000000000..9cd498f7fbdd1c938dc273f9cdd687b4d49a4259 --- /dev/null +++ b/src/core/utils/result/map_expect.rs @@ -0,0 +1,15 @@ +use std::fmt::Debug; + +use super::Result; + +pub trait MapExpect<'a, T> { + /// Calls expect(msg) on the mapped Result value. This is similar to + /// map(Result::unwrap) but composes an expect call and message without + /// requiring a closure. + fn map_expect(self, msg: &'a str) -> T; +} + +impl<'a, T, E: Debug> MapExpect<'a, Option<T>> for Option<Result<T, E>> { + #[inline] + fn map_expect(self, msg: &'a str) -> Option<T> { self.map(|result| result.expect(msg)) } +} diff --git a/src/core/utils/result/not_found.rs b/src/core/utils/result/not_found.rs new file mode 100644 index 0000000000000000000000000000000000000000..d61825afa6a9ba6a3ebdcac1124ab39aea2e9ba8 --- /dev/null +++ b/src/core/utils/result/not_found.rs @@ -0,0 +1,12 @@ +use super::Result; +use crate::Error; + +pub trait NotFound<T> { + #[must_use] + fn is_not_found(&self) -> bool; +} + +impl<T> NotFound<T> for Result<T, Error> { + #[inline] + fn is_not_found(&self) -> bool { self.as_ref().is_err_and(Error::is_not_found) } +} diff --git a/src/core/utils/result/unwrap_infallible.rs b/src/core/utils/result/unwrap_infallible.rs new file mode 100644 index 0000000000000000000000000000000000000000..99309e02551d3a6f7c2881559b8030b1ec45e8ad --- /dev/null +++ b/src/core/utils/result/unwrap_infallible.rs @@ -0,0 +1,17 @@ +use std::convert::Infallible; + +use super::{DebugInspect, Result}; +use crate::error; + +pub trait UnwrapInfallible<T> { + fn unwrap_infallible(self) -> T; +} + +impl<T> UnwrapInfallible<T> for Result<T, Infallible> { + #[inline] + fn unwrap_infallible(self) -> T { + // SAFETY: Branchless unwrap for errors that can never happen. In debug + // mode this is asserted. + unsafe { self.debug_inspect_err(error::infallible).unwrap_unchecked() } + } +} diff --git a/src/core/utils/result/unwrap_or_err.rs b/src/core/utils/result/unwrap_or_err.rs new file mode 100644 index 0000000000000000000000000000000000000000..69901958f3dec8dee6b006a6823ac7c39f022a0f --- /dev/null +++ b/src/core/utils/result/unwrap_or_err.rs @@ -0,0 +1,15 @@ +use std::convert::identity; + +use super::Result; + +/// Returns the Ok value or the Err value. Available when the Ok and Err types +/// are the same. This is a way to default the result using the specific Err +/// value rather than unwrap_or_default() using Ok's default. 
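+///
+/// An illustrative sketch (hypothetical values, not from this codebase):
+///
+/// ```ignore
+/// let res: Result<usize, usize> = Err(42);
+/// assert_eq!(res.unwrap_or_err(), 42); // the Err value doubles as the fallback
+/// ```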
+pub trait UnwrapOrErr<T> { + fn unwrap_or_err(self) -> T; +} + +impl<T> UnwrapOrErr<T> for Result<T, T> { + #[inline] + fn unwrap_or_err(self) -> T { self.unwrap_or_else(identity::<T>) } +} diff --git a/src/core/utils/set.rs b/src/core/utils/set.rs new file mode 100644 index 0000000000000000000000000000000000000000..563f9df5ceeb707fccc99e4235e859a99f188de7 --- /dev/null +++ b/src/core/utils/set.rs @@ -0,0 +1,47 @@ +use std::cmp::{Eq, Ord}; + +use crate::{is_equal_to, is_less_than}; + +/// Intersection of sets +/// +/// Outputs the set of elements common to all input sets. Inputs do not have to +/// be sorted. If inputs are sorted a more optimized function is available in +/// this suite and should be used. +pub fn intersection<Item, Iter, Iters>(mut input: Iters) -> impl Iterator<Item = Item> + Send +where + Iters: Iterator<Item = Iter> + Clone + Send, + Iter: Iterator<Item = Item> + Send, + Item: Eq + Send, +{ + input.next().into_iter().flat_map(move |first| { + let input = input.clone(); + first.filter(move |targ| { + input + .clone() + .all(|mut other| other.any(is_equal_to!(*targ))) + }) + }) +} + +/// Intersection of sets +/// +/// Outputs the set of elements common to all input sets. Inputs must be sorted. +pub fn intersection_sorted<Item, Iter, Iters>(mut input: Iters) -> impl Iterator<Item = Item> + Send +where + Iters: Iterator<Item = Iter> + Clone + Send, + Iter: Iterator<Item = Item> + Send, + Item: Eq + Ord + Send, +{ + input.next().into_iter().flat_map(move |first| { + let mut input = input.clone().collect::<Vec<_>>(); + first.filter(move |targ| { + input.iter_mut().all(|it| { + it.by_ref() + .skip_while(is_less_than!(targ)) + .peekable() + .peek() + .is_some_and(is_equal_to!(targ)) + }) + }) + }) +} diff --git a/src/core/utils/stream/cloned.rs b/src/core/utils/stream/cloned.rs new file mode 100644 index 0000000000000000000000000000000000000000..d6a0e6470eebe4c2d54fe051772805519ea41fd3 --- /dev/null +++ b/src/core/utils/stream/cloned.rs @@ -0,0 +1,20 @@ +use std::clone::Clone; + +use futures::{stream::Map, Stream, StreamExt}; + +pub trait Cloned<'a, T, S> +where + S: Stream<Item = &'a T>, + T: Clone + 'a, +{ + fn cloned(self) -> Map<S, fn(&T) -> T>; +} + +impl<'a, T, S> Cloned<'a, T, S> for S +where + S: Stream<Item = &'a T>, + T: Clone + 'a, +{ + #[inline] + fn cloned(self) -> Map<S, fn(&T) -> T> { self.map(Clone::clone) } +} diff --git a/src/core/utils/stream/expect.rs b/src/core/utils/stream/expect.rs new file mode 100644 index 0000000000000000000000000000000000000000..68ac24cedd501a249071fcefd62fbf032d623920 --- /dev/null +++ b/src/core/utils/stream/expect.rs @@ -0,0 +1,22 @@ +use futures::{Stream, StreamExt, TryStream}; + +use crate::Result; + +pub trait TryExpect<'a, Item> { + fn expect_ok(self) -> impl Stream<Item = Item> + Send + 'a; + + fn map_expect(self, msg: &'a str) -> impl Stream<Item = Item> + Send + 'a; +} + +impl<'a, T, Item> TryExpect<'a, Item> for T +where + T: Stream<Item = Result<Item>> + TryStream + Send + 'a, + Item: 'a, +{ + #[inline] + fn expect_ok(self: T) -> impl Stream<Item = Item> + Send + 'a { self.map_expect("stream expectation failure") } + + //TODO: move to impl MapExpect + #[inline] + fn map_expect(self, msg: &'a str) -> impl Stream<Item = Item> + Send + 'a { self.map(|res| res.expect(msg)) } +} diff --git a/src/core/utils/stream/ignore.rs b/src/core/utils/stream/ignore.rs new file mode 100644 index 0000000000000000000000000000000000000000..997aa4ba44c57c36fe4515a1f226eb9406a9b4e7 --- /dev/null +++ b/src/core/utils/stream/ignore.rs @@ -0,0 
+1,21 @@ +use futures::{future::ready, Stream, StreamExt, TryStream}; + +use crate::{Error, Result}; + +pub trait TryIgnore<'a, Item> { + fn ignore_err(self) -> impl Stream<Item = Item> + Send + 'a; + + fn ignore_ok(self) -> impl Stream<Item = Error> + Send + 'a; +} + +impl<'a, T, Item> TryIgnore<'a, Item> for T +where + T: Stream<Item = Result<Item>> + TryStream + Send + 'a, + Item: Send + 'a, +{ + #[inline] + fn ignore_err(self: T) -> impl Stream<Item = Item> + Send + 'a { self.filter_map(|res| ready(res.ok())) } + + #[inline] + fn ignore_ok(self: T) -> impl Stream<Item = Error> + Send + 'a { self.filter_map(|res| ready(res.err())) } +} diff --git a/src/core/utils/stream/iter_stream.rs b/src/core/utils/stream/iter_stream.rs new file mode 100644 index 0000000000000000000000000000000000000000..69edf64f5385d3ec95674e10568c00c0412a588c --- /dev/null +++ b/src/core/utils/stream/iter_stream.rs @@ -0,0 +1,27 @@ +use futures::{ + stream, + stream::{Stream, TryStream}, + StreamExt, +}; + +pub trait IterStream<I: IntoIterator + Send> { + /// Convert an Iterator into a Stream + fn stream(self) -> impl Stream<Item = <I as IntoIterator>::Item> + Send; + + /// Convert an Iterator into a TryStream + fn try_stream(self) -> impl TryStream<Ok = <I as IntoIterator>::Item, Error = crate::Error> + Send; +} + +impl<I> IterStream<I> for I +where + I: IntoIterator + Send, + <I as IntoIterator>::IntoIter: Send, +{ + #[inline] + fn stream(self) -> impl Stream<Item = <I as IntoIterator>::Item> + Send { stream::iter(self) } + + #[inline] + fn try_stream(self) -> impl TryStream<Ok = <I as IntoIterator>::Item, Error = crate::Error> + Send { + self.stream().map(Ok) + } +} diff --git a/src/core/utils/stream/mod.rs b/src/core/utils/stream/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..1111915b3f4caeaea62db8efbc122f8b38f66557 --- /dev/null +++ b/src/core/utils/stream/mod.rs @@ -0,0 +1,15 @@ +mod cloned; +mod expect; +mod ignore; +mod iter_stream; +mod ready; +mod tools; +mod try_ready; + +pub use cloned::Cloned; +pub use expect::TryExpect; +pub use ignore::TryIgnore; +pub use iter_stream::IterStream; +pub use ready::ReadyExt; +pub use tools::Tools; +pub use try_ready::TryReadyExt; diff --git a/src/core/utils/stream/ready.rs b/src/core/utils/stream/ready.rs new file mode 100644 index 0000000000000000000000000000000000000000..f4eec7d1b3271f39301843939531a2ab3b500d6d --- /dev/null +++ b/src/core/utils/stream/ready.rs @@ -0,0 +1,157 @@ +//! Synchronous combinator extensions to futures::Stream +#![allow(clippy::type_complexity)] + +use futures::{ + future::{ready, Ready}, + stream::{Any, Filter, FilterMap, Fold, ForEach, Scan, SkipWhile, Stream, StreamExt, TakeWhile}, +}; + +/// Synchronous combinators to augment futures::StreamExt. Most Stream +/// combinators take asynchronous arguments, but often only simple predicates +/// are required to steer a Stream like an Iterator. This suite provides a +/// convenience to reduce boilerplate by de-cluttering non-async predicates. +/// +/// This interface is not necessarily complete; feel free to add as-needed. 
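+///
+/// An illustrative sketch (`stream` is a hypothetical `Stream` of `u64`):
+///
+/// ```ignore
+/// // without ReadyExt: stream.filter(|x| ready(x % 2 == 0))
+/// let evens = stream.ready_filter(|x| x % 2 == 0);
+/// ```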
+pub trait ReadyExt<Item> +where + Self: Stream<Item = Item> + Send + Sized, +{ + fn ready_any<F>(self, f: F) -> Any<Self, Ready<bool>, impl FnMut(Item) -> Ready<bool>> + where + F: Fn(Item) -> bool; + + fn ready_filter<'a, F>(self, f: F) -> Filter<Self, Ready<bool>, impl FnMut(&Item) -> Ready<bool> + 'a> + where + F: Fn(&Item) -> bool + 'a; + + fn ready_filter_map<F, U>(self, f: F) -> FilterMap<Self, Ready<Option<U>>, impl FnMut(Item) -> Ready<Option<U>>> + where + F: Fn(Item) -> Option<U>; + + fn ready_fold<T, F>(self, init: T, f: F) -> Fold<Self, Ready<T>, T, impl FnMut(T, Item) -> Ready<T>> + where + F: Fn(T, Item) -> T; + + fn ready_fold_default<T, F>(self, f: F) -> Fold<Self, Ready<T>, T, impl FnMut(T, Item) -> Ready<T>> + where + F: Fn(T, Item) -> T, + T: Default; + + fn ready_for_each<F>(self, f: F) -> ForEach<Self, Ready<()>, impl FnMut(Item) -> Ready<()>> + where + F: FnMut(Item); + + fn ready_take_while<'a, F>(self, f: F) -> TakeWhile<Self, Ready<bool>, impl FnMut(&Item) -> Ready<bool> + 'a> + where + F: Fn(&Item) -> bool + 'a; + + fn ready_scan<B, T, F>( + self, init: T, f: F, + ) -> Scan<Self, T, Ready<Option<B>>, impl FnMut(&mut T, Item) -> Ready<Option<B>>> + where + F: Fn(&mut T, Item) -> Option<B>; + + fn ready_scan_each<T, F>( + self, init: T, f: F, + ) -> Scan<Self, T, Ready<Option<Item>>, impl FnMut(&mut T, Item) -> Ready<Option<Item>>> + where + F: Fn(&mut T, &Item); + + fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile<Self, Ready<bool>, impl FnMut(&Item) -> Ready<bool> + 'a> + where + F: Fn(&Item) -> bool + 'a; +} + +impl<Item, S> ReadyExt<Item> for S +where + S: Stream<Item = Item> + Send + Sized, +{ + #[inline] + fn ready_any<F>(self, f: F) -> Any<Self, Ready<bool>, impl FnMut(Item) -> Ready<bool>> + where + F: Fn(Item) -> bool, + { + self.any(move |t| ready(f(t))) + } + + #[inline] + fn ready_filter<'a, F>(self, f: F) -> Filter<Self, Ready<bool>, impl FnMut(&Item) -> Ready<bool> + 'a> + where + F: Fn(&Item) -> bool + 'a, + { + self.filter(move |t| ready(f(t))) + } + + #[inline] + fn ready_filter_map<F, U>(self, f: F) -> FilterMap<Self, Ready<Option<U>>, impl FnMut(Item) -> Ready<Option<U>>> + where + F: Fn(Item) -> Option<U>, + { + self.filter_map(move |t| ready(f(t))) + } + + #[inline] + fn ready_fold<T, F>(self, init: T, f: F) -> Fold<Self, Ready<T>, T, impl FnMut(T, Item) -> Ready<T>> + where + F: Fn(T, Item) -> T, + { + self.fold(init, move |a, t| ready(f(a, t))) + } + + #[inline] + fn ready_fold_default<T, F>(self, f: F) -> Fold<Self, Ready<T>, T, impl FnMut(T, Item) -> Ready<T>> + where + F: Fn(T, Item) -> T, + T: Default, + { + self.ready_fold(T::default(), f) + } + + #[inline] + #[allow(clippy::unit_arg)] + fn ready_for_each<F>(self, mut f: F) -> ForEach<Self, Ready<()>, impl FnMut(Item) -> Ready<()>> + where + F: FnMut(Item), + { + self.for_each(move |t| ready(f(t))) + } + + #[inline] + fn ready_take_while<'a, F>(self, f: F) -> TakeWhile<Self, Ready<bool>, impl FnMut(&Item) -> Ready<bool> + 'a> + where + F: Fn(&Item) -> bool + 'a, + { + self.take_while(move |t| ready(f(t))) + } + + #[inline] + fn ready_scan<B, T, F>( + self, init: T, f: F, + ) -> Scan<Self, T, Ready<Option<B>>, impl FnMut(&mut T, Item) -> Ready<Option<B>>> + where + F: Fn(&mut T, Item) -> Option<B>, + { + self.scan(init, move |s, t| ready(f(s, t))) + } + + #[inline] + fn ready_scan_each<T, F>( + self, init: T, f: F, + ) -> Scan<Self, T, Ready<Option<Item>>, impl FnMut(&mut T, Item) -> Ready<Option<Item>>> + where + F: Fn(&mut T, &Item), + { + self.ready_scan(init, move |s, t| { + f(s, 
&t); + Some(t) + }) + } + + #[inline] + fn ready_skip_while<'a, F>(self, f: F) -> SkipWhile<Self, Ready<bool>, impl FnMut(&Item) -> Ready<bool> + 'a> + where + F: Fn(&Item) -> bool + 'a, + { + self.skip_while(move |t| ready(f(t))) + } +} diff --git a/src/core/utils/stream/tools.rs b/src/core/utils/stream/tools.rs new file mode 100644 index 0000000000000000000000000000000000000000..cc6b7ca9e712788f1076ab1d4ab883cf1ab6976a --- /dev/null +++ b/src/core/utils/stream/tools.rs @@ -0,0 +1,80 @@ +//! StreamTools for futures::Stream + +use std::{collections::HashMap, hash::Hash}; + +use futures::{Future, Stream, StreamExt}; + +use super::ReadyExt; +use crate::expected; + +/// StreamTools +/// +/// This interface is not necessarily complete; feel free to add as-needed. +pub trait Tools<Item> +where + Self: Stream<Item = Item> + Send + Sized, + <Self as Stream>::Item: Send, +{ + fn counts(self) -> impl Future<Output = HashMap<Item, usize>> + Send + where + <Self as Stream>::Item: Eq + Hash; + + fn counts_by<K, F>(self, f: F) -> impl Future<Output = HashMap<K, usize>> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send; + + fn counts_by_with_cap<const CAP: usize, K, F>(self, f: F) -> impl Future<Output = HashMap<K, usize>> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send; + + fn counts_with_cap<const CAP: usize>(self) -> impl Future<Output = HashMap<Item, usize>> + Send + where + <Self as Stream>::Item: Eq + Hash; +} + +impl<Item, S> Tools<Item> for S +where + S: Stream<Item = Item> + Send + Sized, + <Self as Stream>::Item: Send, +{ + #[inline] + fn counts(self) -> impl Future<Output = HashMap<Item, usize>> + Send + where + <Self as Stream>::Item: Eq + Hash, + { + self.counts_with_cap::<0>() + } + + #[inline] + fn counts_by<K, F>(self, f: F) -> impl Future<Output = HashMap<K, usize>> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send, + { + self.counts_by_with_cap::<0, K, F>(f) + } + + #[inline] + fn counts_by_with_cap<const CAP: usize, K, F>(self, f: F) -> impl Future<Output = HashMap<K, usize>> + Send + where + F: Fn(Item) -> K + Send, + K: Eq + Hash + Send, + { + self.map(f).counts_with_cap::<CAP>() + } + + #[inline] + fn counts_with_cap<const CAP: usize>(self) -> impl Future<Output = HashMap<Item, usize>> + Send + where + <Self as Stream>::Item: Eq + Hash, + { + self.ready_fold(HashMap::with_capacity(CAP), |mut counts, item| { + let entry = counts.entry(item).or_default(); + let value = *entry; + *entry = expected!(value + 1); + counts + }) + } +} diff --git a/src/core/utils/stream/try_ready.rs b/src/core/utils/stream/try_ready.rs new file mode 100644 index 0000000000000000000000000000000000000000..0daed26e49424a37133b658c3286ddd884fc66e1 --- /dev/null +++ b/src/core/utils/stream/try_ready.rs @@ -0,0 +1,86 @@ +//! Synchronous combinator extensions to futures::TryStream +#![allow(clippy::type_complexity)] + +use futures::{ + future::{ready, Ready}, + stream::{AndThen, TryFold, TryForEach, TryStream, TryStreamExt}, +}; + +use crate::Result; + +/// Synchronous combinators to augment futures::TryStreamExt. +/// +/// This interface is not necessarily complete; feel free to add as-needed. 
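+///
+/// An illustrative sketch (`stream` is a hypothetical `TryStream` with
+/// `Ok = u64`):
+///
+/// ```ignore
+/// // fold with a plain fallible closure; short-circuits on the first Err
+/// let sum = stream.ready_try_fold(0_u64, |acc, x| Ok(acc + x)).await?;
+/// ```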
+pub trait TryReadyExt<T, E, S>
+where
+	S: TryStream<Ok = T, Error = E, Item = Result<T, E>> + Send + ?Sized,
+	Self: TryStream + Send + Sized,
+{
+	fn ready_and_then<U, F>(self, f: F) -> AndThen<Self, Ready<Result<U, E>>, impl FnMut(S::Ok) -> Ready<Result<U, E>>>
+	where
+		F: Fn(S::Ok) -> Result<U, E>;
+
+	fn ready_try_for_each<F>(
+		self, f: F,
+	) -> TryForEach<Self, Ready<Result<(), E>>, impl FnMut(S::Ok) -> Ready<Result<(), E>>>
+	where
+		F: FnMut(S::Ok) -> Result<(), E>;
+
+	fn ready_try_fold<U, F>(
+		self, init: U, f: F,
+	) -> TryFold<Self, Ready<Result<U, E>>, U, impl FnMut(U, S::Ok) -> Ready<Result<U, E>>>
+	where
+		F: Fn(U, S::Ok) -> Result<U, E>;
+
+	fn ready_try_fold_default<U, F>(
+		self, f: F,
+	) -> TryFold<Self, Ready<Result<U, E>>, U, impl FnMut(U, S::Ok) -> Ready<Result<U, E>>>
+	where
+		F: Fn(U, S::Ok) -> Result<U, E>,
+		U: Default;
+}
+
+impl<T, E, S> TryReadyExt<T, E, S> for S
+where
+	S: TryStream<Ok = T, Error = E, Item = Result<T, E>> + Send + ?Sized,
+	Self: TryStream + Send + Sized,
+{
+	#[inline]
+	fn ready_and_then<U, F>(self, f: F) -> AndThen<Self, Ready<Result<U, E>>, impl FnMut(S::Ok) -> Ready<Result<U, E>>>
+	where
+		F: Fn(S::Ok) -> Result<U, E>,
+	{
+		self.and_then(move |t| ready(f(t)))
+	}
+
+	#[inline]
+	fn ready_try_for_each<F>(
+		self, mut f: F,
+	) -> TryForEach<Self, Ready<Result<(), E>>, impl FnMut(S::Ok) -> Ready<Result<(), E>>>
+	where
+		F: FnMut(S::Ok) -> Result<(), E>,
+	{
+		self.try_for_each(move |t| ready(f(t)))
+	}
+
+	#[inline]
+	fn ready_try_fold<U, F>(
+		self, init: U, f: F,
+	) -> TryFold<Self, Ready<Result<U, E>>, U, impl FnMut(U, S::Ok) -> Ready<Result<U, E>>>
+	where
+		F: Fn(U, S::Ok) -> Result<U, E>,
+	{
+		self.try_fold(init, move |a, t| ready(f(a, t)))
+	}
+
+	#[inline]
+	fn ready_try_fold_default<U, F>(
+		self, f: F,
+	) -> TryFold<Self, Ready<Result<U, E>>, U, impl FnMut(U, S::Ok) -> Ready<Result<U, E>>>
+	where
+		F: Fn(U, S::Ok) -> Result<U, E>,
+		U: Default,
+	{
+		self.ready_try_fold(U::default(), f)
+	}
+}
diff --git a/src/core/utils/string.rs b/src/core/utils/string.rs
index 85282b30aa8b7e07468f31d1092ebb14ed6e4b03..e65a3369837abf529e8dc0010e166317909c8326 100644
--- a/src/core/utils/string.rs
+++ b/src/core/utils/string.rs
@@ -1,3 +1,10 @@
+mod between;
+mod split;
+mod tests;
+mod unquote;
+mod unquoted;
+
+pub use self::{between::Between, split::SplitInfallible, unquote::Unquote, unquoted::Unquoted};
 use crate::{utils::exchange, Result};
 
 pub const EMPTY: &str = "";
@@ -95,12 +102,6 @@ pub fn common_prefix<'a>(choice: &'a [&str]) -> &'a str {
 	})
 }
 
-#[inline]
-#[must_use]
-pub fn split_once_infallible<'a>(input: &'a str, delim: &'_ str) -> (&'a str, &'a str) {
-	input.split_once(delim).unwrap_or((input, EMPTY))
-}
-
 /// Parses the bytes into a string.
 pub fn string_from_bytes(bytes: &[u8]) -> Result<String> {
 	let str: &str = str_from_bytes(bytes)?;
diff --git a/src/core/utils/string/between.rs b/src/core/utils/string/between.rs
new file mode 100644
index 0000000000000000000000000000000000000000..209a9dabb2fba76278f34a54d663034855be912b
--- /dev/null
+++ b/src/core/utils/string/between.rs
@@ -0,0 +1,26 @@
+type Delim<'a> = (&'a str, &'a str);
+
+/// Slice a string between a pair of delimiters.
+pub trait Between<'a> {
+	/// Extract a string between the delimiters. If the delimiters were not
+	/// found, None is returned; otherwise the first extraction is returned.
+	fn between(&self, delim: Delim<'_>) -> Option<&'a str>;
+
+	/// Extract a string between the delimiters. If the delimiters were not
+	/// found, the original string is returned; take note of this behavior.
+	/// If an empty slice is desired for this case, use the fallible version
+	/// and unwrap to EMPTY.
+	fn between_infallible(&self, delim: Delim<'_>) -> &'a str;
+}
+
+impl<'a> Between<'a> for &'a str {
+	#[inline]
+	fn between_infallible(&self, delim: Delim<'_>) -> &'a str { self.between(delim).unwrap_or(self) }
+
+	#[inline]
+	fn between(&self, delim: Delim<'_>) -> Option<&'a str> {
+		self.split_once(delim.0)
+			.and_then(|(_, b)| b.rsplit_once(delim.1))
+			.map(|(a, _)| a)
+	}
+}
diff --git a/src/core/utils/string/split.rs b/src/core/utils/string/split.rs
new file mode 100644
index 0000000000000000000000000000000000000000..96de28dff575583ee4e9c648d175a5bfece03056
--- /dev/null
+++ b/src/core/utils/string/split.rs
@@ -0,0 +1,22 @@
+use super::EMPTY;
+
+type Pair<'a> = (&'a str, &'a str);
+
+/// Split a string with default behaviors on non-match.
+pub trait SplitInfallible<'a> {
+	/// Split a string at the first occurrence of delim. If not found, the
+	/// entire string is returned in \[0\], while \[1\] is empty.
+	fn split_once_infallible(&self, delim: &str) -> Pair<'a>;
+
+	/// Split a string from the last occurrence of delim. If not found, the
+	/// entire string is returned in \[0\], while \[1\] is empty.
+	fn rsplit_once_infallible(&self, delim: &str) -> Pair<'a>;
+}
+
+impl<'a> SplitInfallible<'a> for &'a str {
+	#[inline]
+	fn rsplit_once_infallible(&self, delim: &str) -> Pair<'a> { self.rsplit_once(delim).unwrap_or((self, EMPTY)) }
+
+	#[inline]
+	fn split_once_infallible(&self, delim: &str) -> Pair<'a> { self.split_once(delim).unwrap_or((self, EMPTY)) }
+}
diff --git a/src/core/utils/string/tests.rs b/src/core/utils/string/tests.rs
new file mode 100644
index 0000000000000000000000000000000000000000..e8c17de6d9749a010e4944098755998d4a1cd96c
--- /dev/null
+++ b/src/core/utils/string/tests.rs
@@ -0,0 +1,70 @@
+#![cfg(test)]
+
+#[test]
+fn common_prefix() {
+	let input = ["conduwuit", "conduit", "construct"];
+	let output = super::common_prefix(&input);
+	assert_eq!(output, "con");
+}
+
+#[test]
+fn common_prefix_empty() {
+	let input = ["abcdefg", "hijklmn", "opqrstu"];
+	let output = super::common_prefix(&input);
+	assert_eq!(output, "");
+}
+
+#[test]
+fn common_prefix_none() {
+	let input = [];
+	let output = super::common_prefix(&input);
+	assert_eq!(output, "");
+}
+
+#[test]
+fn camel_to_snake_case_0() {
+	let res = super::camel_to_snake_string("CamelToSnakeCase");
+	assert_eq!(res, "camel_to_snake_case");
+}
+
+#[test]
+fn camel_to_snake_case_1() {
+	let res = super::camel_to_snake_string("CAmelTOSnakeCase");
+	assert_eq!(res, "camel_tosnake_case");
+}
+
+#[test]
+fn unquote() {
+	use super::Unquote;
+
+	assert_eq!("\"foo\"".unquote(), Some("foo"));
+	assert_eq!("\"foo".unquote(), None);
+	assert_eq!("foo".unquote(), None);
+}
+
+#[test]
+fn unquote_infallible() {
+	use super::Unquote;
+
+	assert_eq!("\"foo\"".unquote_infallible(), "foo");
+	assert_eq!("\"foo".unquote_infallible(), "\"foo");
+	assert_eq!("foo".unquote_infallible(), "foo");
+}
+
+#[test]
+fn between() {
+	use super::Between;
+
+	assert_eq!("\"foo\"".between(("\"", "\"")), Some("foo"));
+	assert_eq!("\"foo".between(("\"", "\"")), None);
+	assert_eq!("foo".between(("\"", "\"")), None);
+}
+
+#[test]
+fn between_infallible() {
+	use super::Between;
+
+	assert_eq!("\"foo\"".between_infallible(("\"", "\"")), "foo");
+	assert_eq!("\"foo".between_infallible(("\"", "\"")), "\"foo");
assert_eq!("foo".between_infallible(("\"", "\"")), "foo"); +} diff --git a/src/core/utils/string/unquote.rs b/src/core/utils/string/unquote.rs new file mode 100644 index 0000000000000000000000000000000000000000..eeded610ab8aec958fbb9999d5259605a197ae1c --- /dev/null +++ b/src/core/utils/string/unquote.rs @@ -0,0 +1,33 @@ +const QUOTE: char = '"'; + +/// Slice a string between quotes +pub trait Unquote<'a> { + /// Whether the input is quoted. If this is false the fallible methods of + /// this interface will fail. + fn is_quoted(&self) -> bool; + + /// Unquotes a string. If the input is not quoted it is simply returned + /// as-is. If the input is partially quoted on either end that quote is not + /// removed. + fn unquote(&self) -> Option<&'a str>; + + /// Unquotes a string. The input must be quoted on each side for Some to be + /// returned + fn unquote_infallible(&self) -> &'a str; +} + +impl<'a> Unquote<'a> for &'a str { + #[inline] + fn unquote_infallible(&self) -> &'a str { + self.strip_prefix(QUOTE) + .unwrap_or(self) + .strip_suffix(QUOTE) + .unwrap_or(self) + } + + #[inline] + fn unquote(&self) -> Option<&'a str> { self.strip_prefix(QUOTE).and_then(|s| s.strip_suffix(QUOTE)) } + + #[inline] + fn is_quoted(&self) -> bool { self.starts_with(QUOTE) && self.ends_with(QUOTE) } +} diff --git a/src/core/utils/string/unquoted.rs b/src/core/utils/string/unquoted.rs new file mode 100644 index 0000000000000000000000000000000000000000..5b002d99b6cc385bd26ddf9f782740c05a99da54 --- /dev/null +++ b/src/core/utils/string/unquoted.rs @@ -0,0 +1,52 @@ +use std::ops::Deref; + +use serde::{de, Deserialize, Deserializer}; + +use super::Unquote; +use crate::{err, Result}; + +/// Unquoted string which deserialized from a quoted string. Construction from a +/// &str is infallible such that the input can already be unquoted. Construction +/// from serde deserialization is fallible and the input must be quoted. +#[repr(transparent)] +pub struct Unquoted(str); + +impl<'a> Unquoted { + #[inline] + #[must_use] + pub fn as_str(&'a self) -> &'a str { &self.0 } +} + +impl<'a, 'de: 'a> Deserialize<'de> for &'a Unquoted { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { + let s = <&'a str>::deserialize(deserializer)?; + s.is_quoted() + .then_some(s) + .ok_or(err!(SerdeDe("expected quoted string"))) + .map_err(de::Error::custom) + .map(Into::into) + } +} + +impl<'a> From<&'a str> for &'a Unquoted { + fn from(s: &'a str) -> &'a Unquoted { + let s: &'a str = s.unquote_infallible(); + + //SAFETY: This is a pattern I lifted from ruma-identifiers for strong-type strs + // by wrapping in a tuple-struct. 
+ #[allow(clippy::transmute_ptr_to_ptr)] + unsafe { + std::mem::transmute(s) + } + } +} + +impl Deref for Unquoted { + type Target = str; + + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl<'a> AsRef<str> for &'a Unquoted { + fn as_ref(&self) -> &'a str { &self.0 } +} diff --git a/src/core/utils/sys.rs b/src/core/utils/sys.rs index 6c396921ce116a1d5acd193eb22998282e9c2d3b..af8bd70b736cc50614acd0807299990f2347a4a2 100644 --- a/src/core/utils/sys.rs +++ b/src/core/utils/sys.rs @@ -1,6 +1,4 @@ -use tracing::debug; - -use crate::Result; +use crate::{debug, Result}; /// This is needed for opening lots of file descriptors, which tends to /// happen more often when using RocksDB and making lots of federation diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index e91accdf49fcc77c3910177ce5a5e68892964e79..84d35936ebfc5482fdcff33956f2a0cda66afca8 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -36,33 +36,6 @@ fn increment_wrap() { assert_eq!(res, 0); } -#[test] -fn common_prefix() { - use utils::string; - - let input = ["conduwuit", "conduit", "construct"]; - let output = string::common_prefix(&input); - assert_eq!(output, "con"); -} - -#[test] -fn common_prefix_empty() { - use utils::string; - - let input = ["abcdefg", "hijklmn", "opqrstu"]; - let output = string::common_prefix(&input); - assert_eq!(output, ""); -} - -#[test] -fn common_prefix_none() { - use utils::string; - - let input = []; - let output = string::common_prefix(&input); - assert_eq!(output, ""); -} - #[test] fn checked_add() { use crate::checked; @@ -136,17 +109,131 @@ async fn mutex_map_contend() { } #[test] -fn camel_to_snake_case_0() { - use utils::string::camel_to_snake_string; +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_none() { + use utils::set::intersection; + + let a: [&str; 0] = []; + let b: [&str; 0] = []; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + + let a: [&str; 0] = []; + let b = ["abc", "def"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + let i = [b.iter(), a.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + let i = [a.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); + + let a = ["foo", "bar", "baz"]; + let b = ["def", "hij", "klm", "nop"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert_eq!(r.count(), 0); +} - let res = camel_to_snake_string("CamelToSnakeCase"); - assert_eq!(res, "camel_to_snake_case"); +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_all() { + use utils::set::intersection; + + let a = ["foo"]; + let b = ["foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["foo", "bar"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "bar"].iter())); + let i = [b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + + let a = ["foo", "bar", "baz"]; + let b = ["baz", "foo", "bar"]; + let c = ["bar", "baz", "foo"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "bar", "baz"].iter())); } #[test] -fn camel_to_snake_case_1() { - use utils::string::camel_to_snake_string; +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn 
set_intersection_some() { + use utils::set::intersection; + + let a = ["foo"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + let i = [b.iter(), a.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["abcdef", "foo", "hijkl", "abc"]; + let b = ["hij", "bar", "baz", "abc", "foo"]; + let c = ["abc", "xyz", "foo", "ghi"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection(i.into_iter()); + assert!(r.eq(["foo", "abc"].iter())); +} - let res = camel_to_snake_string("CAmelTOSnakeCase"); - assert_eq!(res, "camel_tosnake_case"); +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_sorted_some() { + use utils::set::intersection_sorted; + + let a = ["bar"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar"].iter())); + let i = [b.iter(), a.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar"].iter())); + + let a = ["aaa", "ccc", "eee", "ggg"]; + let b = ["aaa", "bbb", "ccc", "ddd", "eee"]; + let c = ["bbb", "ccc", "eee", "fff"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["ccc", "eee"].iter())); +} + +#[test] +#[allow(clippy::iter_on_single_items, clippy::many_single_char_names)] +fn set_intersection_sorted_all() { + use utils::set::intersection_sorted; + + let a = ["foo"]; + let b = ["foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["foo"].iter())); + + let a = ["bar", "foo"]; + let b = ["bar", "foo"]; + let i = [a.iter(), b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + let i = [b.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "foo"].iter())); + + let a = ["bar", "baz", "foo"]; + let b = ["bar", "baz", "foo"]; + let c = ["bar", "baz", "foo"]; + let i = [a.iter(), b.iter(), c.iter()]; + let r = intersection_sorted(i.into_iter()); + assert!(r.eq(["bar", "baz", "foo"].iter())); } diff --git a/src/core/utils/time.rs b/src/core/utils/time.rs index 04f47ac38098fef90db6b48ee9c0c85fa3bd4d44..f96a27d0092456ca37778854c272e8bceb35c9b9 100644 --- a/src/core/utils/time.rs +++ b/src/core/utils/time.rs @@ -22,6 +22,13 @@ pub fn timepoint_ago(duration: Duration) -> Result<SystemTime> { .ok_or_else(|| err!(Arithmetic("Duration {duration:?} is too large"))) } +#[inline] +pub fn timepoint_from_now(duration: Duration) -> Result<SystemTime> { + SystemTime::now() + .checked_add(duration) + .ok_or_else(|| err!(Arithmetic("Duration {duration:?} is too large"))) +} + #[inline] pub fn parse_duration(duration: &str) -> Result<Duration> { cyborgtime::parse_duration(duration) diff --git a/src/database/Cargo.toml b/src/database/Cargo.toml index 34d98416dad0b00de130647a2a25df7cad3d70bc..0e718aa717ccbd66a78a39e3689d4ad0f07c5cfa 100644 --- a/src/database/Cargo.toml +++ b/src/database/Cargo.toml @@ -35,10 +35,14 @@ zstd_compression = [ ] [dependencies] +arrayvec.workspace = true conduit-core.workspace = true const-str.workspace = true +futures.workspace = true log.workspace = true rust-rocksdb.workspace = true +serde.workspace = true +serde_json.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/src/database/cork.rs b/src/database/cork.rs index 26c520a2800ac2ee755523a11a416b5ac0fdb1a4..5fe5fd7ab7e1c2acadfe4d9ecccd9bdb046902a1 
100644 --- a/src/database/cork.rs +++ b/src/database/cork.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::Engine; +use crate::{Database, Engine}; pub struct Cork { db: Arc<Engine>, @@ -8,6 +8,20 @@ pub struct Cork { sync: bool, } +impl Database { + #[inline] + #[must_use] + pub fn cork(&self) -> Cork { Cork::new(&self.db, false, false) } + + #[inline] + #[must_use] + pub fn cork_and_flush(&self) -> Cork { Cork::new(&self.db, true, false) } + + #[inline] + #[must_use] + pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) } +} + impl Cork { #[inline] pub(super) fn new(db: &Arc<Engine>, flush: bool, sync: bool) -> Self { diff --git a/src/database/database.rs b/src/database/database.rs index c357d50f2978b491a1fbee5dce3323763e24acb3..40aec31235ea79634743efd8fe67296e0bdd0319 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -1,9 +1,8 @@ use std::{ops::Index, sync::Arc}; -use conduit::{Result, Server}; +use conduit::{err, Result, Server}; use crate::{ - cork::Cork, maps, maps::{Maps, MapsKey, MapsVal}, Engine, Map, @@ -11,7 +10,7 @@ pub struct Database { pub db: Arc<Engine>, - map: Maps, + maps: Maps, } impl Database { @@ -20,31 +19,37 @@ pub async fn open(server: &Arc<Server>) -> Result<Arc<Self>> { let db = Engine::open(server)?; Ok(Arc::new(Self { db: db.clone(), - map: maps::open(&db)?, + maps: maps::open(&db)?, })) } #[inline] - #[must_use] - pub fn cork(&self) -> Cork { Cork::new(&self.db, false, false) } + pub fn get(&self, name: &str) -> Result<&Arc<Map>> { + self.maps + .get(name) + .ok_or_else(|| err!(Request(NotFound("column not found")))) + } #[inline] - #[must_use] - pub fn cork_and_flush(&self) -> Cork { Cork::new(&self.db, true, false) } + pub fn iter(&self) -> impl Iterator<Item = (&MapsKey, &MapsVal)> + Send + '_ { self.maps.iter() } + + #[inline] + pub fn keys(&self) -> impl Iterator<Item = &MapsKey> + Send + '_ { self.maps.keys() } #[inline] #[must_use] - pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) } + pub fn is_read_only(&self) -> bool { self.db.is_read_only() } #[inline] - pub fn iter_maps(&self) -> impl Iterator<Item = (&MapsKey, &MapsVal)> + '_ { self.map.iter() } + #[must_use] + pub fn is_secondary(&self) -> bool { self.db.is_secondary() } } impl Index<&str> for Database { type Output = Arc<Map>; fn index(&self, name: &str) -> &Self::Output { - self.map + self.maps .get(name) .expect("column in database does not exist") } diff --git a/src/database/de.rs b/src/database/de.rs new file mode 100644 index 0000000000000000000000000000000000000000..f8a038ef852a0cfe7c187e151fa012ac030fb691 --- /dev/null +++ b/src/database/de.rs @@ -0,0 +1,377 @@ +use conduit::{checked, debug::DebugInspect, err, utils::string, Error, Result}; +use serde::{ + de, + de::{DeserializeSeed, Visitor}, + Deserialize, +}; + +use crate::util::unhandled; + +/// Deserialize into T from buffer. +pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result<T> +where + T: Deserialize<'a>, +{ + let mut deserializer = Deserializer { + buf, + pos: 0, + seq: false, + }; + + T::deserialize(&mut deserializer).debug_inspect(|_| { + deserializer + .finished() + .expect("deserialization failed to consume trailing bytes"); + }) +} + +/// Deserialization state. +pub(crate) struct Deserializer<'de> { + buf: &'de [u8], + pos: usize, + seq: bool, +} + +/// Directive to ignore a record. This type can be used to skip deserialization +/// until the next separator is found. 
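+///
+/// For example, deserializing a hypothetical key type `(Ignore, &str)` skips
+/// the first record and borrows the second.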
+#[derive(Debug, Deserialize)]
+pub struct Ignore;
+
+/// Directive to ignore all remaining records. This can be provided in a
+/// sequence to skip every remaining element.
+#[derive(Debug, Deserialize)]
+pub struct IgnoreAll;
+
+impl<'de> Deserializer<'de> {
+	const SEP: u8 = crate::ser::SEP;
+
+	/// Determine if the input was fully consumed, erroring if bytes remain.
+	/// This is intended for debug assertions; not optimized for parsing logic.
+	fn finished(&self) -> Result<()> {
+		let pos = self.pos;
+		let len = self.buf.len();
+		let parsed = &self.buf[0..pos];
+		let unparsed = &self.buf[pos..];
+		let remain = checked!(len - pos)?;
+		let trailing_sep = remain == 1 && unparsed[0] == Self::SEP;
+		(remain == 0 || trailing_sep)
+			.then_some(())
+			.ok_or(err!(SerdeDe(
+				"{remain} trailing of {len} bytes not deserialized.\n{parsed:?}\n{unparsed:?}",
+			)))
+	}
+
+	/// Called at the start of arrays and tuples
+	#[inline]
+	fn sequence_start(&mut self) {
+		debug_assert!(!self.seq, "Nested sequences are not handled at this time");
+		self.seq = true;
+	}
+
+	/// Consume the current record to ignore it. Inside a sequence only the
+	/// next record is skipped; at the top-level all records are skipped such
+	/// that deserialization completes with self.finished() == Ok.
+	#[inline]
+	fn record_ignore(&mut self) {
+		if self.seq {
+			self.record_next();
+		} else {
+			self.record_ignore_all();
+		}
+	}
+
+	/// Consume the current and all remaining records to ignore them. Similar
+	/// to Ignore at the top-level, but it can be provided in a sequence to
+	/// ignore all remaining elements.
+	#[inline]
+	fn record_ignore_all(&mut self) { self.record_trail(); }
+
+	/// Consume the current record. The position pointer is moved to the start
+	/// of the next record. A slice of the current record is returned.
+	#[inline]
+	fn record_next(&mut self) -> &'de [u8] {
+		self.buf[self.pos..]
+			.split(|b| *b == Deserializer::SEP)
+			.inspect(|record| self.inc_pos(record.len()))
+			.next()
+			.expect("remainder of buf even if SEP was not found")
+	}
+
+	/// Peek at the first byte of the current record. If all records were
+	/// consumed None is returned instead.
+	#[inline]
+	fn record_peek_byte(&self) -> Option<u8> {
+		let started = self.pos != 0;
+		let buf = &self.buf[self.pos..];
+		debug_assert!(
+			!started || buf[0] == Self::SEP,
+			"Missing expected record separator at current position"
+		);
+
+		buf.get::<usize>(started.into()).copied()
+	}
+
+	/// Consume the record separator such that the position cleanly points to
+	/// the start of the next record. (Case for some sequences)
+	#[inline]
+	fn record_start(&mut self) {
+		let started = self.pos != 0;
+		debug_assert!(
+			!started || self.buf[self.pos] == Self::SEP,
+			"Missing expected record separator at current position"
+		);
+
+		self.inc_pos(started.into());
+	}
+
+	/// Consume all remaining bytes, which may include record separators,
+	/// returning a raw slice.
+	#[inline]
+	fn record_trail(&mut self) -> &'de [u8] {
+		let record = &self.buf[self.pos..];
+		self.inc_pos(record.len());
+		record
+	}
+
+	/// Increment the position pointer.
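+	/// Saturating add; a debug assertion verifies the result stays within the
+	/// buffer.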
+ #[inline] + fn inc_pos(&mut self, n: usize) { + self.pos = self.pos.saturating_add(n); + debug_assert!(self.pos <= self.buf.len(), "pos out of range"); + } +} + +impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.sequence_start(); + visitor.visit_seq(self) + } + + fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.sequence_start(); + visitor.visit_seq(self) + } + + fn deserialize_tuple_struct<V>(self, _name: &'static str, _len: usize, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.sequence_start(); + visitor.visit_seq(self) + } + + fn deserialize_map<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + let input = self.record_next(); + let mut d = serde_json::Deserializer::from_slice(input); + d.deserialize_map(visitor).map_err(Into::into) + } + + fn deserialize_struct<V>(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + let input = self.record_next(); + let mut d = serde_json::Deserializer::from_slice(input); + d.deserialize_struct(name, fields, visitor) + .map_err(Into::into) + } + + fn deserialize_unit_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + match name { + "Ignore" => self.record_ignore(), + "IgnoreAll" => self.record_ignore_all(), + _ => unhandled!("Unrecognized deserialization Directive {name:?}"), + }; + + visitor.visit_unit() + } + + fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + match name { + "$serde_json::private::RawValue" => visitor.visit_map(self), + _ => visitor.visit_newtype_struct(self), + } + } + + fn deserialize_enum<V>( + self, _name: &'static str, _variants: &'static [&'static str], _visitor: V, + ) -> Result<V::Value> + where + V: Visitor<'de>, + { + unhandled!("deserialize Enum not implemented") + } + + fn deserialize_option<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> { + unhandled!("deserialize Option not implemented") + } + + fn deserialize_bool<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> { + unhandled!("deserialize bool not implemented") + } + + fn deserialize_i8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> { + unhandled!("deserialize i8 not implemented") + } + + fn deserialize_i16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> { + unhandled!("deserialize i16 not implemented") + } + + fn deserialize_i32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> { + unhandled!("deserialize i32 not implemented") + } + + fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> { + let bytes: [u8; size_of::<i64>()] = self.buf[self.pos..].try_into()?; + self.inc_pos(size_of::<i64>()); + visitor.visit_i64(i64::from_be_bytes(bytes)) + } + + fn deserialize_u8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> { + unhandled!("deserialize u8 not implemented; try dereferencing the Handle for [u8] access instead") + } + + fn deserialize_u16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> { + unhandled!("deserialize u16 not implemented") + } + + fn deserialize_u32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> { + unhandled!("deserialize u32 not implemented") + } + + fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> { + 
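+		// Fixed-width integers are stored as exactly eight big-endian bytes,
+		// not as a SEP-delimited record; the try_into fails unless exactly
+		// that many bytes remain.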
+		let bytes: [u8; size_of::<u64>()] = self.buf[self.pos..].try_into()?;
+		self.inc_pos(size_of::<u64>());
+		visitor.visit_u64(u64::from_be_bytes(bytes))
+	}
+
+	fn deserialize_f32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
+		unhandled!("deserialize f32 not implemented")
+	}
+
+	fn deserialize_f64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
+		unhandled!("deserialize f64 not implemented")
+	}
+
+	fn deserialize_char<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
+		unhandled!("deserialize char not implemented")
+	}
+
+	fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
+		let input = self.record_next();
+		let out = deserialize_str(input)?;
+		visitor.visit_borrowed_str(out)
+	}
+
+	fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
+		let input = self.record_next();
+		let out = string::string_from_bytes(input)?;
+		visitor.visit_string(out)
+	}
+
+	fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
+		let input = self.record_trail();
+		visitor.visit_borrowed_bytes(input)
+	}
+
+	fn deserialize_byte_buf<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
+		unhandled!("deserialize Byte Buf not implemented")
+	}
+
+	fn deserialize_unit<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
+		unhandled!("deserialize Unit not implemented")
+	}
+
+	// this is only used for $serde_json::private::RawValue at this time; see MapAccess
+	fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
+		let input = "$serde_json::private::RawValue";
+		visitor.visit_borrowed_str(input)
+	}
+
+	fn deserialize_ignored_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
+		unhandled!("deserialize Ignored Any not implemented")
+	}
+
+	fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
+		debug_assert_eq!(
+			conduit::debug::type_name::<V>(),
+			"serde_json::value::de::<impl serde::de::Deserialize for \
+			serde_json::value::Value>::deserialize::ValueVisitor",
+			"deserialize_any: type not expected"
+		);
+
+		match self.record_peek_byte() {
+			Some(b'{') => self.deserialize_map(visitor),
+			_ => self.deserialize_str(visitor),
+		}
+	}
+}
+
+impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> {
+	type Error = Error;
+
+	fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
+	where
+		T: DeserializeSeed<'de>,
+	{
+		if self.pos >= self.buf.len() {
+			return Ok(None);
+		}
+
+		self.record_start();
+		seed.deserialize(&mut **self).map(Some)
+	}
+}
+
+// this is only used for $serde_json::private::RawValue at this time. Our db
+// schema doesn't have its own map format; we use JSON for that anyway.
+impl<'a, 'de: 'a> de::MapAccess<'de> for &'a mut Deserializer<'de> {
+	type Error = Error;
+
+	fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
+	where
+		K: DeserializeSeed<'de>,
+	{
+		seed.deserialize(&mut **self).map(Some)
+	}
+
+	fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
+	where
+		V: DeserializeSeed<'de>,
+	{
+		seed.deserialize(&mut **self)
+	}
+}
+
+// activate when stable; too soon now
+//#[cfg(debug_assertions)]
+#[inline]
+fn deserialize_str(input: &[u8]) -> Result<&str> { string::str_from_bytes(input) }
+
+//#[cfg(not(debug_assertions))]
+#[cfg(disable)]
+#[inline]
+fn deserialize_str(input: &[u8]) -> Result<&str> {
+	// SAFETY: Strings were written by the serializer to the database. Assuming no
+	// database corruption, the string will be valid.
Database corruption is + // detected via rocksdb checksums. + unsafe { std::str::from_utf8_unchecked(input) } +} diff --git a/src/database/deserialized.rs b/src/database/deserialized.rs new file mode 100644 index 0000000000000000000000000000000000000000..a59b2ce5479bcd57b9e36dea7b72fc087eeffef3 --- /dev/null +++ b/src/database/deserialized.rs @@ -0,0 +1,20 @@ +use std::convert::identity; + +use conduit::Result; +use serde::Deserialize; + +pub trait Deserialized { + fn map_de<T, U, F>(self, f: F) -> Result<U> + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>; + + #[inline] + fn deserialized<T>(self) -> Result<T> + where + T: for<'de> Deserialize<'de>, + Self: Sized, + { + self.map_de(identity::<T>) + } +} diff --git a/src/database/engine.rs b/src/database/engine.rs index 3850c1d3f2bbd0aae529a06df1af80458ad37203..1fa53b01210de5f2d453934432eab42bb983e8b3 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -10,13 +10,14 @@ use rocksdb::{ backup::{BackupEngine, BackupEngineOptions}, perf::get_memory_usage_stats, - AsColumnFamilyRef, BoundColumnFamily, Cache, ColumnFamilyDescriptor, DBCommon, DBWithThreadMode, Env, + AsColumnFamilyRef, BoundColumnFamily, Cache, ColumnFamilyDescriptor, DBCommon, DBWithThreadMode, Env, LogLevel, MultiThreaded, Options, }; use crate::{ opts::{cf_options, db_options}, or_else, result, + util::map_err, }; pub struct Engine { @@ -28,6 +29,8 @@ pub struct Engine { cfs: Mutex<BTreeSet<String>>, pub(crate) db: Db, corks: AtomicU32, + pub(super) read_only: bool, + pub(super) secondary: bool, } pub(crate) type Db = DBWithThreadMode<MultiThreaded>; @@ -80,10 +83,13 @@ pub(crate) fn open(server: &Arc<Server>) -> Result<Arc<Self>> { .collect::<Vec<_>>(); debug!("Opening database..."); + let path = &config.database_path; let res = if config.rocksdb_read_only { - Db::open_cf_for_read_only(&db_opts, &config.database_path, cfs.clone(), false) + Db::open_cf_descriptors_read_only(&db_opts, path, cfds, false) + } else if config.rocksdb_secondary { + Db::open_cf_descriptors_as_secondary(&db_opts, path, path, cfds) } else { - Db::open_cf_descriptors(&db_opts, &config.database_path, cfds) + Db::open_cf_descriptors(&db_opts, path, cfds) }; let db = res.or_else(or_else)?; @@ -103,10 +109,12 @@ pub(crate) fn open(server: &Arc<Server>) -> Result<Arc<Self>> { cfs: Mutex::new(cfs), db, corks: AtomicU32::new(0), + read_only: config.rocksdb_read_only, + secondary: config.rocksdb_secondary, })) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "trace")] pub(crate) fn open_cf(&self, name: &str) -> Result<Arc<BoundColumnFamily<'_>>> { let mut cfs = self.cfs.lock().expect("locked"); if !cfs.contains(name) { @@ -176,19 +184,20 @@ pub fn cleanup(&self) -> Result<()> { } #[tracing::instrument(skip(self))] - pub fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { + pub fn backup(&self) -> Result { let config = &self.server.config; let path = config.database_backup_path.as_ref(); if path.is_none() || path.is_some_and(|path| path.as_os_str().is_empty()) { return Ok(()); } - let options = BackupEngineOptions::new(path.expect("valid database backup path"))?; - let mut engine = BackupEngine::open(&options, &self.env)?; + let options = BackupEngineOptions::new(path.expect("valid database backup path")).map_err(map_err)?; + let mut engine = BackupEngine::open(&options, &self.env).map_err(map_err)?; if config.database_backups_to_keep > 0 { - if let Err(e) = engine.create_new_backup_flush(&self.db, true) { - return Err(Box::new(e)); - } + let 
flush = !self.is_read_only(); + engine + .create_new_backup_flush(&self.db, flush) + .map_err(map_err)?; let engine_info = engine.get_backup_info(); let info = &engine_info.last().expect("backup engine info is not empty"); @@ -267,6 +276,14 @@ pub(crate) fn property(&self, cf: &impl AsColumnFamilyRef, name: &str) -> Result result(self.db.property_value_cf(cf, name)) .and_then(|val| val.map_or_else(|| Err!("Property {name:?} not found."), Ok)) } + + #[inline] + #[must_use] + pub fn is_read_only(&self) -> bool { self.secondary || self.read_only } + + #[inline] + #[must_use] + pub fn is_secondary(&self) -> bool { self.secondary } } pub(crate) fn repair(db_opts: &Options, path: &PathBuf) -> Result<()> { @@ -279,6 +296,21 @@ pub(crate) fn repair(db_opts: &Options, path: &PathBuf) -> Result<()> { Ok(()) } +#[tracing::instrument(skip_all, name = "rocksdb")] +pub(crate) fn handle_log(level: LogLevel, msg: &str) { + let msg = msg.trim(); + if msg.starts_with("Options") { + return; + } + + match level { + LogLevel::Header | LogLevel::Debug => debug!("{msg}"), + LogLevel::Error | LogLevel::Fatal => error!("{msg}"), + LogLevel::Info => debug!("{msg}"), + LogLevel::Warn => warn!("{msg}"), + }; +} + impl Drop for Engine { #[cold] fn drop(&mut self) { diff --git a/src/database/handle.rs b/src/database/handle.rs index 0b45a75f07f54200dc13df5cd90d1bbba5e3fc22..daee224d4742096535bea59903a1b48276bfe748 100644 --- a/src/database/handle.rs +++ b/src/database/handle.rs @@ -1,6 +1,10 @@ -use std::ops::Deref; +use std::{fmt, fmt::Debug, ops::Deref}; +use conduit::Result; use rocksdb::DBPinnableSlice; +use serde::{Deserialize, Serialize, Serializer}; + +use crate::{keyval::deserialize_val, Deserialized, Slice}; pub struct Handle<'a> { val: DBPinnableSlice<'a>, @@ -14,14 +18,67 @@ fn from(val: DBPinnableSlice<'a>) -> Self { } } +impl Debug for Handle<'_> { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { + let val: &Slice = self; + let ptr = val.as_ptr(); + let len = val.len(); + write!(out, "Handle {{val: {{ptr: {ptr:?}, len: {len}}}}}") + } +} + +impl Serialize for Handle<'_> { + #[inline] + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + let bytes: &Slice = self; + serializer.serialize_bytes(bytes) + } +} + +impl Deserialized for Result<Handle<'_>> { + #[inline] + fn map_de<T, U, F>(self, f: F) -> Result<U> + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self?.map_de(f) + } +} + +impl<'a> Deserialized for Result<&'a Handle<'a>> { + #[inline] + fn map_de<T, U, F>(self, f: F) -> Result<U> + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + self.and_then(|handle| handle.map_de(f)) + } +} + +impl<'a> Deserialized for &'a Handle<'a> { + fn map_de<T, U, F>(self, f: F) -> Result<U> + where + F: FnOnce(T) -> U, + T: for<'de> Deserialize<'de>, + { + deserialize_val(self.as_ref()).map(f) + } +} + +impl From<Handle<'_>> for Vec<u8> { + fn from(handle: Handle<'_>) -> Self { handle.deref().to_vec() } +} + impl Deref for Handle<'_> { - type Target = [u8]; + type Target = Slice; #[inline] fn deref(&self) -> &Self::Target { &self.val } } -impl AsRef<[u8]> for Handle<'_> { +impl AsRef<Slice> for Handle<'_> { #[inline] - fn as_ref(&self) -> &[u8] { &self.val } + fn as_ref(&self) -> &Slice { &self.val } } diff --git a/src/database/iter.rs b/src/database/iter.rs deleted file mode 100644 index 4845e97739511bebea07ac7ed5390da0cf284a5b..0000000000000000000000000000000000000000 --- a/src/database/iter.rs +++ /dev/null @@ -1,110 +0,0 @@ -use 
std::{iter::FusedIterator, sync::Arc}; - -use conduit::Result; -use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, Direction, IteratorMode, ReadOptions}; - -use crate::{ - engine::Db, - result, - slice::{OwnedKeyVal, OwnedKeyValPair}, - Engine, -}; - -type Cursor<'cursor> = DBRawIteratorWithThreadMode<'cursor, Db>; - -struct State<'cursor> { - cursor: Cursor<'cursor>, - direction: Direction, - valid: bool, - init: bool, -} - -impl<'cursor> State<'cursor> { - pub(crate) fn new( - db: &'cursor Arc<Engine>, cf: &'cursor Arc<ColumnFamily>, opts: ReadOptions, mode: &IteratorMode<'_>, - ) -> Self { - let mut cursor = db.db.raw_iterator_cf_opt(&**cf, opts); - let direction = into_direction(mode); - let valid = seek_init(&mut cursor, mode); - Self { - cursor, - direction, - valid, - init: true, - } - } -} - -pub struct Iter<'cursor> { - state: State<'cursor>, -} - -impl<'cursor> Iter<'cursor> { - pub(crate) fn new( - db: &'cursor Arc<Engine>, cf: &'cursor Arc<ColumnFamily>, opts: ReadOptions, mode: &IteratorMode<'_>, - ) -> Self { - Self { - state: State::new(db, cf, opts, mode), - } - } -} - -impl Iterator for Iter<'_> { - type Item = OwnedKeyValPair; - - fn next(&mut self) -> Option<Self::Item> { - if !self.state.init && self.state.valid { - seek_next(&mut self.state.cursor, self.state.direction); - } else if self.state.init { - self.state.init = false; - } - - self.state - .cursor - .item() - .map(OwnedKeyVal::from) - .map(OwnedKeyVal::to_tuple) - .or_else(|| { - when_invalid(&mut self.state).expect("iterator invalidated due to error"); - None - }) - } -} - -impl FusedIterator for Iter<'_> {} - -fn when_invalid(state: &mut State<'_>) -> Result<()> { - state.valid = false; - result(state.cursor.status()) -} - -fn seek_next(cursor: &mut Cursor<'_>, direction: Direction) { - match direction { - Direction::Forward => cursor.next(), - Direction::Reverse => cursor.prev(), - } -} - -fn seek_init(cursor: &mut Cursor<'_>, mode: &IteratorMode<'_>) -> bool { - use Direction::{Forward, Reverse}; - use IteratorMode::{End, From, Start}; - - match mode { - Start => cursor.seek_to_first(), - End => cursor.seek_to_last(), - From(key, Forward) => cursor.seek(key), - From(key, Reverse) => cursor.seek_for_prev(key), - }; - - cursor.valid() -} - -fn into_direction(mode: &IteratorMode<'_>) -> Direction { - use Direction::{Forward, Reverse}; - use IteratorMode::{End, From, Start}; - - match mode { - Start | From(_, Forward) => Forward, - End | From(_, Reverse) => Reverse, - } -} diff --git a/src/database/keyval.rs b/src/database/keyval.rs new file mode 100644 index 0000000000000000000000000000000000000000..a288f1842ec83ff67b8b5a3b0a32491c566f5c11 --- /dev/null +++ b/src/database/keyval.rs @@ -0,0 +1,75 @@ +use conduit::Result; +use serde::Deserialize; + +use crate::de; + +pub type KeyVal<'a, K = &'a Slice, V = &'a Slice> = (Key<'a, K>, Val<'a, V>); +pub type Key<'a, T = &'a Slice> = T; +pub type Val<'a, T = &'a Slice> = T; + +pub type Slice = [u8]; + +#[inline] +pub(crate) fn _expect_deserialize<'a, K, V>(kv: Result<KeyVal<'a>>) -> KeyVal<'a, K, V> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + result_deserialize(kv).expect("failed to deserialize result key/val") +} + +#[inline] +pub(crate) fn _expect_deserialize_key<'a, K>(key: Result<Key<'a>>) -> Key<'a, K> +where + K: Deserialize<'a>, +{ + result_deserialize_key(key).expect("failed to deserialize result key") +} + +#[inline] +pub(crate) fn result_deserialize<'a, K, V>(kv: Result<KeyVal<'a>>) -> Result<KeyVal<'a, K, V>> +where + K: Deserialize<'a>, + 
V: Deserialize<'a>, +{ + deserialize(kv?) +} + +#[inline] +pub(crate) fn result_deserialize_key<'a, K>(key: Result<Key<'a>>) -> Result<Key<'a, K>> +where + K: Deserialize<'a>, +{ + deserialize_key(key?) +} + +#[inline] +pub(crate) fn deserialize<'a, K, V>(kv: KeyVal<'a>) -> Result<KeyVal<'a, K, V>> +where + K: Deserialize<'a>, + V: Deserialize<'a>, +{ + Ok((deserialize_key::<K>(kv.0)?, deserialize_val::<V>(kv.1)?)) +} + +#[inline] +pub(crate) fn deserialize_key<'a, K>(key: Key<'a>) -> Result<Key<'a, K>> +where + K: Deserialize<'a>, +{ + de::from_slice::<K>(key) +} + +#[inline] +pub(crate) fn deserialize_val<'a, V>(val: Val<'a>) -> Result<Val<'a, V>> +where + V: Deserialize<'a>, +{ + de::from_slice::<V>(val) +} + +#[inline] +pub fn key<K, V>(kv: KeyVal<'_, K, V>) -> Key<'_, K> { kv.0 } + +#[inline] +pub fn val<K, V>(kv: KeyVal<'_, K, V>) -> Val<'_, V> { kv.1 } diff --git a/src/database/map.rs b/src/database/map.rs index ddae8c8136f60cb3b0c0e9d02262c22cd240c092..d6b8bf38ce94a0b2ed4175e0ed143e22f1afc848 100644 --- a/src/database/map.rs +++ b/src/database/map.rs @@ -1,16 +1,35 @@ -use std::{ffi::CStr, future::Future, mem::size_of, pin::Pin, sync::Arc}; - -use conduit::{utils, Result}; -use rocksdb::{ - AsColumnFamilyRef, ColumnFamily, Direction, IteratorMode, ReadOptions, WriteBatchWithTransaction, WriteOptions, +mod contains; +mod count; +mod get; +mod insert; +mod keys; +mod keys_from; +mod keys_prefix; +mod remove; +mod rev_keys; +mod rev_keys_from; +mod rev_keys_prefix; +mod rev_stream; +mod rev_stream_from; +mod rev_stream_prefix; +mod stream; +mod stream_from; +mod stream_prefix; + +use std::{ + convert::AsRef, + ffi::CStr, + fmt, + fmt::{Debug, Display}, + future::Future, + pin::Pin, + sync::Arc, }; -use crate::{ - or_else, result, - slice::{Byte, Key, KeyVal, OwnedKey, OwnedKeyValPair, OwnedVal, Val}, - watchers::Watchers, - Engine, Handle, Iter, -}; +use conduit::Result; +use rocksdb::{AsColumnFamilyRef, ColumnFamily, ReadOptions, WriteOptions}; + +use crate::{watchers::Watchers, Engine}; pub struct Map { name: String, @@ -21,8 +40,6 @@ pub struct Map { read_options: ReadOptions, } -type OwnedKeyValPairIter<'a> = Box<dyn Iterator<Item = OwnedKeyValPair> + Send + 'a>; - impl Map { pub(crate) fn open(db: &Arc<Engine>, name: &str) -> Result<Arc<Self>> { Ok(Arc::new(Self { @@ -35,162 +52,18 @@ pub(crate) fn open(db: &Arc<Engine>, name: &str) -> Result<Arc<Self>> { })) } - pub fn get(&self, key: &Key) -> Result<Option<Handle<'_>>> { - let read_options = &self.read_options; - let res = self.db.db.get_pinned_cf_opt(&self.cf(), key, read_options); - - Ok(result(res)?.map(Handle::from)) - } - - pub fn multi_get(&self, keys: &[&Key]) -> Result<Vec<Option<OwnedVal>>> { - // Optimization can be `true` if key vector is pre-sorted **by the column - // comparator**. 
- const SORTED: bool = false; - - let mut ret: Vec<Option<OwnedKey>> = Vec::with_capacity(keys.len()); - let read_options = &self.read_options; - for res in self - .db - .db - .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) - { - match res { - Ok(Some(res)) => ret.push(Some((*res).to_vec())), - Ok(None) => ret.push(None), - Err(e) => return or_else(e), - } - } - - Ok(ret) - } - - pub fn insert(&self, key: &Key, value: &Val) -> Result<()> { - let write_options = &self.write_options; - self.db - .db - .put_cf_opt(&self.cf(), key, value, write_options) - .or_else(or_else)?; - - if !self.db.corked() { - self.db.flush()?; - } - - self.watchers.wake(key); - - Ok(()) - } - - pub fn insert_batch<'a, I>(&'a self, iter: I) -> Result<()> - where - I: Iterator<Item = KeyVal<'a>>, - { - let mut batch = WriteBatchWithTransaction::<false>::default(); - for KeyVal(key, value) in iter { - batch.put_cf(&self.cf(), key, value); - } - - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn remove(&self, key: &Key) -> Result<()> { - let write_options = &self.write_options; - let res = self.db.db.delete_cf_opt(&self.cf(), key, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn remove_batch<'a, I>(&'a self, iter: I) -> Result<()> - where - I: Iterator<Item = &'a Key>, - { - let mut batch = WriteBatchWithTransaction::<false>::default(); - for key in iter { - batch.delete_cf(&self.cf(), key); - } - - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn iter(&self) -> OwnedKeyValPairIter<'_> { - let mode = IteratorMode::Start; - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode)) - } - - pub fn iter_from(&self, from: &Key, reverse: bool) -> OwnedKeyValPairIter<'_> { - let direction = if reverse { - Direction::Reverse - } else { - Direction::Forward - }; - let mode = IteratorMode::From(from, direction); - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode)) - } - - pub fn scan_prefix(&self, prefix: OwnedKey) -> OwnedKeyValPairIter<'_> { - let mode = IteratorMode::From(&prefix, Direction::Forward); - let read_options = read_options_default(); - Box::new(Iter::new(&self.db, &self.cf, read_options, &mode).take_while(move |(k, _)| k.starts_with(&prefix))) - } - - pub fn increment(&self, key: &Key) -> Result<[Byte; size_of::<u64>()]> { - let old = self.get(key)?; - let new = utils::increment(old.as_deref()); - self.insert(key, &new)?; - - if !self.db.corked() { - self.db.flush()?; - } - - Ok(new) - } - - pub fn increment_batch<'a, I>(&'a self, iter: I) -> Result<()> + #[inline] + pub fn watch_prefix<'a, K>(&'a self, prefix: &K) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> where - I: Iterator<Item = &'a Key>, + K: AsRef<[u8]> + ?Sized + Debug, { - let mut batch = WriteBatchWithTransaction::<false>::default(); - for key in iter { - let old = self.get(key)?; - let new = utils::increment(old.as_deref()); - batch.put_cf(&self.cf(), key, new); - } - - let write_options = &self.write_options; - let res = self.db.db.write_opt(batch, write_options); - - if !self.db.corked() { - self.db.flush()?; - } - - result(res) - } - - pub fn watch_prefix<'a>(&'a self, prefix: &Key) -> Pin<Box<dyn Future<Output = ()> + Send 
+ 'a>> { - self.watchers.watch(prefix) + self.watchers.watch(prefix.as_ref()) } + #[inline] pub fn property_integer(&self, name: &CStr) -> Result<u64> { self.db.property_integer(&self.cf(), name) } + #[inline] pub fn property(&self, name: &str) -> Result<String> { self.db.property(&self.cf(), name) } #[inline] @@ -199,12 +72,12 @@ pub fn name(&self) -> &str { &self.name } fn cf(&self) -> impl AsColumnFamilyRef + '_ { &*self.cf } } -impl<'a> IntoIterator for &'a Map { - type IntoIter = Box<dyn Iterator<Item = Self::Item> + Send + 'a>; - type Item = OwnedKeyValPair; +impl Debug for Map { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { write!(out, "Map {{name: {0}}}", self.name) } +} - #[inline] - fn into_iter(self) -> Self::IntoIter { self.iter() } +impl Display for Map { + fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result { write!(out, "{0}", self.name) } } fn open(db: &Arc<Engine>, name: &str) -> Result<Arc<ColumnFamily>> { @@ -212,10 +85,7 @@ fn open(db: &Arc<Engine>, name: &str) -> Result<Arc<ColumnFamily>> { let bounded_ptr = Arc::into_raw(bounded_arc); let cf_ptr = bounded_ptr.cast::<ColumnFamily>(); - // SAFETY: After thorough contemplation this appears to be the best solution, - // even by a significant margin. - // - // BACKGROUND: Column family handles out of RocksDB are basic pointers and can + // SAFETY: Column family handles out of RocksDB are basic pointers and can // be invalidated: 1. when the database closes. 2. when the column is dropped or // closed. rust_rocksdb wraps this for us by storing handles in their own // `RwLock<BTreeMap>` map and returning an Arc<BoundColumnFamily<'_>>` to diff --git a/src/database/map/contains.rs b/src/database/map/contains.rs new file mode 100644 index 0000000000000000000000000000000000000000..a98fe7c539cf8fec25559684ac890620739caf1a --- /dev/null +++ b/src/database/map/contains.rs @@ -0,0 +1,88 @@ +use std::{convert::AsRef, fmt::Debug, future::Future, io::Write}; + +use arrayvec::ArrayVec; +use conduit::{implement, utils::TryFutureExtExt, Err, Result}; +use futures::future::ready; +use serde::Serialize; + +use crate::{ser, util}; + +/// Returns true if the map contains the key. +/// - key is serialized into allocated buffer +/// - harder errors may not be reported +#[implement(super::Map)] +pub fn contains<K>(&self, key: &K) -> impl Future<Output = bool> + Send +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = Vec::<u8>::with_capacity(64); + self.bcontains(key, &mut buf) +} + +/// Returns true if the map contains the key. +/// - key is serialized into stack-buffer +/// - harder errors will panic +#[implement(super::Map)] +pub fn acontains<const MAX: usize, K>(&self, key: &K) -> impl Future<Output = bool> + Send +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = ArrayVec::<u8, MAX>::new(); + self.bcontains(key, &mut buf) +} + +/// Returns true if the map contains the key. +/// - key is serialized into provided buffer +/// - harder errors will panic +#[implement(super::Map)] +#[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] +pub fn bcontains<K, B>(&self, key: &K, buf: &mut B) -> impl Future<Output = bool> + Send +where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, +{ + let key = ser::serialize(buf, key).expect("failed to serialize query key"); + self.exists(key).is_ok() +} + +/// Returns Ok if the map contains the key. 
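// [Editorial sketch, not part of the patch] Typical use of the probes above,
// assuming a `map: &Map` handle and key shapes chosen only for illustration:
//
//	if map.acontains::<256, _>(&(user_id, device_id)).await {
//		// boolean form; the key is serialized into a 256-byte stack buffer
//	}
//	map.exists(b"raw-key-bytes").await?; // Result form: Err(NotFound) when absent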
+/// - key is raw +#[implement(super::Map)] +pub fn exists<K>(&self, key: &K) -> impl Future<Output = Result<()>> + Send +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + ready(self.exists_blocking(key)) +} + +/// Returns Ok if the map contains the key; NotFound otherwise. Harder errors +/// may not always be reported properly. +#[implement(super::Map)] +#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] +pub fn exists_blocking<K>(&self, key: &K) -> Result<()> +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + if self.maybe_exists_blocking(key) + && self + .db + .db + .get_pinned_cf_opt(&self.cf(), key, &self.read_options) + .map_err(util::map_err)? + .is_some() + { + Ok(()) + } else { + Err!(Request(NotFound("Not found in database"))) + } +} + +#[implement(super::Map)] +fn maybe_exists_blocking<K>(&self, key: &K) -> bool +where + K: AsRef<[u8]> + ?Sized, +{ + self.db + .db + .key_may_exist_cf_opt(&self.cf(), key, &self.read_options) +} diff --git a/src/database/map/count.rs b/src/database/map/count.rs new file mode 100644 index 0000000000000000000000000000000000000000..3e92279c0f4a02d90d90a6f2664fecba86563a49 --- /dev/null +++ b/src/database/map/count.rs @@ -0,0 +1,58 @@ +use std::{fmt::Debug, future::Future}; + +use conduit::implement; +use futures::stream::StreamExt; +use serde::Serialize; + +/// Count the total number of entries in the map. +#[implement(super::Map)] +#[inline] +pub fn count(&self) -> impl Future<Output = usize> + Send + '_ { self.raw_keys().count() } + +/// Count the number of entries in the map starting from a lower-bound. +/// +/// - From is a structured key +#[implement(super::Map)] +#[inline] +pub fn count_from<'a, P>(&'a self, from: &P) -> impl Future<Output = usize> + Send + 'a +where + P: Serialize + ?Sized + Debug + 'a, +{ + self.keys_from_raw(from).count() +} + +/// Count the number of entries in the map starting from a lower-bound. +/// +/// - From is a raw +#[implement(super::Map)] +#[inline] +pub fn raw_count_from<'a, P>(&'a self, from: &'a P) -> impl Future<Output = usize> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_keys_from(from).count() +} + +/// Count the number of entries in the map matching a prefix. +/// +/// - Prefix is structured key +#[implement(super::Map)] +#[inline] +pub fn count_prefix<'a, P>(&'a self, prefix: &P) -> impl Future<Output = usize> + Send + 'a +where + P: Serialize + ?Sized + Debug + 'a, +{ + self.keys_prefix_raw(prefix).count() +} + +/// Count the number of entries in the map matching a prefix. +/// +/// - Prefix is raw +#[implement(super::Map)] +#[inline] +pub fn raw_count_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Future<Output = usize> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_keys_prefix(prefix).count() +} diff --git a/src/database/map/get.rs b/src/database/map/get.rs new file mode 100644 index 0000000000000000000000000000000000000000..2f7df0318752f00fd06d008b22fba24fe789abb0 --- /dev/null +++ b/src/database/map/get.rs @@ -0,0 +1,110 @@ +use std::{convert::AsRef, fmt::Debug, future::Future, io::Write}; + +use arrayvec::ArrayVec; +use conduit::{err, implement, utils::IterStream, Result}; +use futures::{future::ready, Stream}; +use rocksdb::DBPinnableSlice; +use serde::Serialize; + +use crate::{ser, util, Handle}; + +type RocksdbResult<'a> = Result<Option<DBPinnableSlice<'a>>, rocksdb::Error>; + +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. 
The key is serialized into an allocated buffer to perform +/// the query. +#[implement(super::Map)] +pub fn qry<K>(&self, key: &K) -> impl Future<Output = Result<Handle<'_>>> + Send +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = Vec::<u8>::with_capacity(64); + self.bqry(key, &mut buf) +} + +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is serialized into a fixed-sized buffer to perform +/// the query. The maximum size is supplied as const generic parameter. +#[implement(super::Map)] +pub fn aqry<const MAX: usize, K>(&self, key: &K) -> impl Future<Output = Result<Handle<'_>>> + Send +where + K: Serialize + ?Sized + Debug, +{ + let mut buf = ArrayVec::<u8, MAX>::new(); + self.bqry(key, &mut buf) +} + +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is serialized into a user-supplied Writer. +#[implement(super::Map)] +#[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] +pub fn bqry<K, B>(&self, key: &K, buf: &mut B) -> impl Future<Output = Result<Handle<'_>>> + Send +where + K: Serialize + ?Sized + Debug, + B: Write + AsRef<[u8]>, +{ + let key = ser::serialize(buf, key).expect("failed to serialize query key"); + self.get(key) +} + +/// Fetch a value from the database into cache, returning a reference-handle +/// asynchronously. The key is referenced directly to perform the query. +#[implement(super::Map)] +#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] +pub fn get<K>(&self, key: &K) -> impl Future<Output = Result<Handle<'_>>> + Send +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + ready(self.get_blocking(key)) +} + +/// Fetch a value from the database into cache, returning a reference-handle. +/// The key is referenced directly to perform the query. This is a thread- +/// blocking call. +#[implement(super::Map)] +pub fn get_blocking<K>(&self, key: &K) -> Result<Handle<'_>> +where + K: AsRef<[u8]> + ?Sized, +{ + let res = self + .db + .db + .get_pinned_cf_opt(&self.cf(), key, &self.read_options); + + into_result_handle(res) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, keys), fields(%self), level = "trace")] +pub fn get_batch<'a, I, K>(&self, keys: I) -> impl Stream<Item = Result<Handle<'_>>> +where + I: Iterator<Item = &'a K> + ExactSizeIterator + Send + Debug, + K: AsRef<[u8]> + Send + Sync + Sized + Debug + 'a, +{ + self.get_batch_blocking(keys).stream() +} + +#[implement(super::Map)] +pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> impl Iterator<Item = Result<Handle<'_>>> +where + I: Iterator<Item = &'a K> + ExactSizeIterator + Send, + K: AsRef<[u8]> + Sized + 'a, +{ + // Optimization can be `true` if key vector is pre-sorted **by the column + // comparator**. + const SORTED: bool = false; + + let read_options = &self.read_options; + self.db + .db + .batched_multi_get_cf_opt(&self.cf(), keys, SORTED, read_options) + .into_iter() + .map(into_result_handle) +} + +fn into_result_handle(result: RocksdbResult<'_>) -> Result<Handle<'_>> { + result + .map_err(util::map_err)? + .map(Handle::from) + .ok_or(err!(Request(NotFound("Not found in database")))) +} diff --git a/src/database/map/insert.rs b/src/database/map/insert.rs new file mode 100644 index 0000000000000000000000000000000000000000..39a0c422e6a9e924fc053109b9802a746833c2e4 --- /dev/null +++ b/src/database/map/insert.rs @@ -0,0 +1,225 @@ +//! Insert a Key+Value into the database. +//! +//! 
Overloads are provided for the user to choose the most efficient
+//! serialization or bypass for pre-serialized (raw) inputs.
+
+use std::{convert::AsRef, fmt::Debug, io::Write};
+
+use arrayvec::ArrayVec;
+use conduit::implement;
+use rocksdb::WriteBatchWithTransaction;
+use serde::Serialize;
+
+use crate::{ser, util::or_else};
+
+/// Insert Key/Value
+///
+/// - Key is serialized
+/// - Val is serialized
+#[implement(super::Map)]
+pub fn put<K, V>(&self, key: K, val: V)
+where
+	K: Serialize + Debug,
+	V: Serialize,
+{
+	let mut key_buf = Vec::new();
+	let mut val_buf = Vec::new();
+	self.bput(key, val, (&mut key_buf, &mut val_buf));
+}
+
+/// Insert Key/Value
+///
+/// - Key is serialized
+/// - Val is raw
+#[implement(super::Map)]
+pub fn put_raw<K, V>(&self, key: K, val: V)
+where
+	K: Serialize + Debug,
+	V: AsRef<[u8]>,
+{
+	let mut key_buf = Vec::new();
+	self.bput_raw(key, val, &mut key_buf);
+}
+
+/// Insert Key/Value
+///
+/// - Key is raw
+/// - Val is serialized
+#[implement(super::Map)]
+pub fn raw_put<K, V>(&self, key: K, val: V)
+where
+	K: AsRef<[u8]>,
+	V: Serialize,
+{
+	let mut val_buf = Vec::new();
+	self.raw_bput(key, val, &mut val_buf);
+}
+
+/// Insert Key/Value
+///
+/// - Key is serialized
+/// - Val is serialized to stack-buffer
+#[implement(super::Map)]
+pub fn put_aput<const VMAX: usize, K, V>(&self, key: K, val: V)
+where
+	K: Serialize + Debug,
+	V: Serialize,
+{
+	let mut key_buf = Vec::new();
+	let mut val_buf = ArrayVec::<u8, VMAX>::new();
+	self.bput(key, val, (&mut key_buf, &mut val_buf));
+}
+
+/// Insert Key/Value
+///
+/// - Key is serialized to stack-buffer
+/// - Val is serialized
+#[implement(super::Map)]
+pub fn aput_put<const KMAX: usize, K, V>(&self, key: K, val: V)
+where
+	K: Serialize + Debug,
+	V: Serialize,
+{
+	let mut key_buf = ArrayVec::<u8, KMAX>::new();
+	let mut val_buf = Vec::new();
+	self.bput(key, val, (&mut key_buf, &mut val_buf));
+}
+
+/// Insert Key/Value
+///
+/// - Key is serialized to stack-buffer
+/// - Val is serialized to stack-buffer
+#[implement(super::Map)]
+pub fn aput<const KMAX: usize, const VMAX: usize, K, V>(&self, key: K, val: V)
+where
+	K: Serialize + Debug,
+	V: Serialize,
+{
+	let mut key_buf = ArrayVec::<u8, KMAX>::new();
+	let mut val_buf = ArrayVec::<u8, VMAX>::new();
+	self.bput(key, val, (&mut key_buf, &mut val_buf));
+}
+
+/// Insert Key/Value
+///
+/// - Key is serialized to stack-buffer
+/// - Val is raw
+#[implement(super::Map)]
+pub fn aput_raw<const KMAX: usize, K, V>(&self, key: K, val: V)
+where
+	K: Serialize + Debug,
+	V: AsRef<[u8]>,
+{
+	let mut key_buf = ArrayVec::<u8, KMAX>::new();
+	self.bput_raw(key, val, &mut key_buf);
+}
+
+/// Insert Key/Value
+///
+/// - Key is raw
+/// - Val is serialized to stack-buffer
+#[implement(super::Map)]
+pub fn raw_aput<const VMAX: usize, K, V>(&self, key: K, val: V)
+where
+	K: AsRef<[u8]>,
+	V: Serialize,
+{
+	let mut val_buf = ArrayVec::<u8, VMAX>::new();
+	self.raw_bput(key, val, &mut val_buf);
+}
+
+/// Insert Key/Value
+///
+/// - Key is serialized to supplied buffer
+/// - Val is serialized to supplied buffer
+#[implement(super::Map)]
+pub fn bput<K, V, Bk, Bv>(&self, key: K, val: V, mut buf: (Bk, Bv))
+where
+	K: Serialize + Debug,
+	V: Serialize,
+	Bk: Write + AsRef<[u8]>,
+	Bv: Write + AsRef<[u8]>,
+{
+	let val = ser::serialize(&mut buf.1, val).expect("failed to serialize insertion val");
+	self.bput_raw(key, val, &mut buf.0);
+}
+
+/// Insert Key/Value
+///
+/// - Key is serialized to supplied buffer
+/// - Val is raw
+#[implement(super::Map)]
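// [Editorial note, not part of the patch] The overload grammar this file
// establishes: a leading `a` serializes into a stack ArrayVec sized by const
// generic, a leading `b` serializes into a caller-supplied buffer, and `raw`
// on either side of the name bypasses serialization for that side.
// Hypothetical calls for illustration:
//
//	map.put(key, val);                    // both sides serialized to heap buffers
//	map.aput::<64, 64, _, _>(key, val);   // both sides on the stack
//	map.put_raw(key, b"already-encoded"); // serialized key, raw value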
+pub fn bput_raw<K, V, Bk>(&self, key: K, val: V, mut buf: Bk) +where + K: Serialize + Debug, + V: AsRef<[u8]>, + Bk: Write + AsRef<[u8]>, +{ + let key = ser::serialize(&mut buf, key).expect("failed to serialize insertion key"); + self.insert(&key, val); +} + +/// Insert Key/Value +/// +/// - Key is raw +/// - Val is serialized to supplied buffer +#[implement(super::Map)] +pub fn raw_bput<K, V, Bv>(&self, key: K, val: V, mut buf: Bv) +where + K: AsRef<[u8]>, + V: Serialize, + Bv: Write + AsRef<[u8]>, +{ + let val = ser::serialize(&mut buf, val).expect("failed to serialize insertion val"); + self.insert(&key, val); +} + +/// Insert Key/Value +/// +/// - Key is raw +/// - Val is raw +#[implement(super::Map)] +#[tracing::instrument(skip_all, fields(%self), level = "trace")] +pub fn insert<K, V>(&self, key: &K, val: V) +where + K: AsRef<[u8]> + ?Sized, + V: AsRef<[u8]>, +{ + let write_options = &self.write_options; + self.db + .db + .put_cf_opt(&self.cf(), key, val, write_options) + .or_else(or_else) + .expect("database insert error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } + + self.watchers.wake(key.as_ref()); +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, iter), fields(%self), level = "trace")] +pub fn insert_batch<'a, I, K, V>(&'a self, iter: I) +where + I: Iterator<Item = &'a (K, V)> + Send + Debug, + K: AsRef<[u8]> + Sized + Debug + 'a, + V: AsRef<[u8]> + Sized + 'a, +{ + let mut batch = WriteBatchWithTransaction::<false>::default(); + for (key, val) in iter { + batch.put_cf(&self.cf(), key.as_ref(), val.as_ref()); + } + + let write_options = &self.write_options; + self.db + .db + .write_opt(batch, write_options) + .or_else(or_else) + .expect("database insert batch error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } +} diff --git a/src/database/map/keys.rs b/src/database/map/keys.rs new file mode 100644 index 0000000000000000000000000000000000000000..2396494c419decddc4444c7d7a730777862ac2bb --- /dev/null +++ b/src/database/map/keys.rs @@ -0,0 +1,21 @@ +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::Key, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys<'a, K>(&'a self) -> impl Stream<Item = Result<Key<'_, K>>> + Send +where + K: Deserialize<'a> + Send, +{ + self.raw_keys().map(keyval::result_deserialize_key::<K>) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys(&self) -> impl Stream<Item = Result<Key<'_>>> + Send { + let opts = super::read_options_default(); + stream::Keys::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/keys_from.rs b/src/database/map/keys_from.rs new file mode 100644 index 0000000000000000000000000000000000000000..4eb3b12e567b31b94fd4b0e003626aec6521e5bc --- /dev/null +++ b/src/database/map/keys_from.rs @@ -0,0 +1,49 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_from<'a, K, P>(&'a self, from: &P) -> impl Stream<Item = Result<Key<'_, K>>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.keys_from_raw(from) + 
.map(keyval::result_deserialize_key::<K>) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_from_raw<P>(&self, from: &P) -> impl Stream<Item = Result<Key<'_>>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.raw_keys_from(&key) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_raw_from<'a, K, P>(&'a self, from: &P) -> impl Stream<Item = Result<Key<'_, K>>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, +{ + self.raw_keys_from(from) + .map(keyval::result_deserialize_key::<K>) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys_from<P>(&self, from: &P) -> impl Stream<Item = Result<Key<'_>>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::Keys::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/keys_prefix.rs b/src/database/map/keys_prefix.rs new file mode 100644 index 0000000000000000000000000000000000000000..0ff755f354add016e2866bc49569054b54c34fa0 --- /dev/null +++ b/src/database/map/keys_prefix.rs @@ -0,0 +1,54 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_prefix<'a, K, P>(&'a self, prefix: &P) -> impl Stream<Item = Result<Key<'_, K>>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.keys_prefix_raw(prefix) + .map(keyval::result_deserialize_key::<K>) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_prefix_raw<P>(&self, prefix: &P) -> impl Stream<Item = Result<Key<'_>>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.raw_keys_from(&key) + .try_take_while(move |k: &Key<'_>| future::ok(k.starts_with(&key))) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn keys_raw_prefix<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream<Item = Result<Key<'_, K>>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, +{ + self.raw_keys_prefix(prefix) + .map(keyval::result_deserialize_key::<K>) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_keys_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream<Item = Result<Key<'_>>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_keys_from(prefix) + .try_take_while(|k: &Key<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/remove.rs b/src/database/map/remove.rs new file mode 100644 index 0000000000000000000000000000000000000000..42eaa477dccb9487b88f7691869fc3af12c5479e --- /dev/null +++ b/src/database/map/remove.rs @@ -0,0 +1,54 @@ +use std::{convert::AsRef, fmt::Debug, io::Write}; + +use arrayvec::ArrayVec; +use conduit::implement; +use serde::Serialize; + +use crate::{ser, util::or_else}; + +#[implement(super::Map)] +pub fn del<K>(&self, key: K) +where + K: 
Serialize + Debug, +{ + let mut buf = Vec::<u8>::with_capacity(64); + self.bdel(key, &mut buf); +} + +#[implement(super::Map)] +pub fn adel<const MAX: usize, K>(&self, key: K) +where + K: Serialize + Debug, +{ + let mut buf = ArrayVec::<u8, MAX>::new(); + self.bdel(key, &mut buf); +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, buf), fields(%self), level = "trace")] +pub fn bdel<K, B>(&self, key: K, buf: &mut B) +where + K: Serialize + Debug, + B: Write + AsRef<[u8]>, +{ + let key = ser::serialize(buf, key).expect("failed to serialize deletion key"); + self.remove(key); +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self, key), fields(%self), level = "trace")] +pub fn remove<K>(&self, key: &K) +where + K: AsRef<[u8]> + ?Sized + Debug, +{ + let write_options = &self.write_options; + self.db + .db + .delete_cf_opt(&self.cf(), key, write_options) + .or_else(or_else) + .expect("database remove error"); + + if !self.db.corked() { + self.db.flush().expect("database flush error"); + } +} diff --git a/src/database/map/rev_keys.rs b/src/database/map/rev_keys.rs new file mode 100644 index 0000000000000000000000000000000000000000..449ccfff39a58e99dd531a829eeeb072be758a24 --- /dev/null +++ b/src/database/map/rev_keys.rs @@ -0,0 +1,21 @@ +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::Key, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys<'a, K>(&'a self) -> impl Stream<Item = Result<Key<'_, K>>> + Send +where + K: Deserialize<'a> + Send, +{ + self.rev_raw_keys().map(keyval::result_deserialize_key::<K>) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys(&self) -> impl Stream<Item = Result<Key<'_>>> + Send { + let opts = super::read_options_default(); + stream::KeysRev::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/rev_keys_from.rs b/src/database/map/rev_keys_from.rs new file mode 100644 index 0000000000000000000000000000000000000000..b142718ced4ab6df7ebd1fc4393521bc4859d359 --- /dev/null +++ b/src/database/map/rev_keys_from.rs @@ -0,0 +1,49 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser, stream}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_from<'a, K, P>(&'a self, from: &P) -> impl Stream<Item = Result<Key<'_, K>>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.rev_keys_from_raw(from) + .map(keyval::result_deserialize_key::<K>) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_from_raw<P>(&self, from: &P) -> impl Stream<Item = Result<Key<'_>>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_raw_keys_from(&key) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_raw_from<'a, K, P>(&'a self, from: &P) -> impl Stream<Item = Result<Key<'_, K>>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, +{ + self.rev_raw_keys_from(from) + .map(keyval::result_deserialize_key::<K>) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), 
fields(%self), level = "trace")] +pub fn rev_raw_keys_from<P>(&self, from: &P) -> impl Stream<Item = Result<Key<'_>>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::KeysRev::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/rev_keys_prefix.rs b/src/database/map/rev_keys_prefix.rs new file mode 100644 index 0000000000000000000000000000000000000000..5297cecf9efea5ef0838877920764214816f806d --- /dev/null +++ b/src/database/map/rev_keys_prefix.rs @@ -0,0 +1,54 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::Key, ser}; + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_prefix<'a, K, P>(&'a self, prefix: &P) -> impl Stream<Item = Result<Key<'_, K>>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, +{ + self.rev_keys_prefix_raw(prefix) + .map(keyval::result_deserialize_key::<K>) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_prefix_raw<P>(&self, prefix: &P) -> impl Stream<Item = Result<Key<'_>>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.rev_raw_keys_from(&key) + .try_take_while(move |k: &Key<'_>| future::ok(k.starts_with(&key))) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_keys_raw_prefix<'a, K, P>(&'a self, prefix: &'a P) -> impl Stream<Item = Result<Key<'_, K>>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, +{ + self.rev_raw_keys_prefix(prefix) + .map(keyval::result_deserialize_key::<K>) +} + +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_keys_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream<Item = Result<Key<'_>>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.rev_raw_keys_from(prefix) + .try_take_while(|k: &Key<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/rev_stream.rs b/src/database/map/rev_stream.rs new file mode 100644 index 0000000000000000000000000000000000000000..de22fd5ce0ace90289a26f0f098f8c6fbbbef9ee --- /dev/null +++ b/src/database/map/rev_stream.rs @@ -0,0 +1,29 @@ +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::KeyVal, stream}; + +/// Iterate key-value entries in the map from the end. +/// +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream<'a, K, V>(&'a self) -> impl Stream<Item = Result<KeyVal<'_, K, V>>> + Send +where + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_raw_stream() + .map(keyval::result_deserialize::<K, V>) +} + +/// Iterate key-value entries in the map from the end. 
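// [Editorial sketch, not part of the patch] A typical consumer drains the map
// newest-first; the `map` handle here is an assumption for illustration:
//
//	use futures::StreamExt;
//	let mut items = map.rev_raw_stream();
//	while let Some(Ok((key, val))) = items.next().await { /* newest first */ }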
+/// +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream(&self) -> impl Stream<Item = Result<KeyVal<'_>>> + Send { + let opts = super::read_options_default(); + stream::ItemsRev::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/rev_stream_from.rs b/src/database/map/rev_stream_from.rs new file mode 100644 index 0000000000000000000000000000000000000000..78318a7fef2ae976cd03512f8354f49b0861d4e6 --- /dev/null +++ b/src/database/map/rev_stream_from.rs @@ -0,0 +1,67 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser, stream}; + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream<Item = Result<KeyVal<'_, K, V>>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_stream_from_raw(from) + .map(keyval::result_deserialize::<K, V>) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_from_raw<P>(&self, from: &P) -> impl Stream<Item = Result<KeyVal<'_>>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.rev_raw_stream_from(&key) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_raw_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream<Item = Result<KeyVal<'_, K, V>>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_raw_stream_from(from) + .map(keyval::result_deserialize::<K, V>) +} + +/// Iterate key-value entries in the map starting from upper-bound. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream_from<P>(&self, from: &P) -> impl Stream<Item = Result<KeyVal<'_>>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::ItemsRev::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/rev_stream_prefix.rs b/src/database/map/rev_stream_prefix.rs new file mode 100644 index 0000000000000000000000000000000000000000..601c3298c575eb196a7cade77c29000262328055 --- /dev/null +++ b/src/database/map/rev_stream_prefix.rs @@ -0,0 +1,74 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser}; + +/// Iterate key-value entries in the map where the key matches a prefix. 
+/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_prefix<'a, K, V, P>(&'a self, prefix: &P) -> impl Stream<Item = Result<KeyVal<'_, K, V>>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.rev_stream_prefix_raw(prefix) + .map(keyval::result_deserialize::<K, V>) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_prefix_raw<P>(&self, prefix: &P) -> impl Stream<Item = Result<KeyVal<'_>>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.rev_raw_stream_from(&key) + .try_take_while(move |(k, _): &KeyVal<'_>| future::ok(k.starts_with(&key))) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_stream_raw_prefix<'a, K, V, P>( + &'a self, prefix: &'a P, +) -> impl Stream<Item = Result<KeyVal<'_, K, V>>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, + V: Deserialize<'a> + Send + 'a, +{ + self.rev_raw_stream_prefix(prefix) + .map(keyval::result_deserialize::<K, V>) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn rev_raw_stream_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream<Item = Result<KeyVal<'_>>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.rev_raw_stream_from(prefix) + .try_take_while(|(k, _): &KeyVal<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/map/stream.rs b/src/database/map/stream.rs new file mode 100644 index 0000000000000000000000000000000000000000..dfbea072991e3a3648103f9e0cc583aec19d5872 --- /dev/null +++ b/src/database/map/stream.rs @@ -0,0 +1,28 @@ +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::Deserialize; + +use crate::{keyval, keyval::KeyVal, stream}; + +/// Iterate key-value entries in the map from the beginning. +/// +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream<'a, K, V>(&'a self) -> impl Stream<Item = Result<KeyVal<'_, K, V>>> + Send +where + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.raw_stream().map(keyval::result_deserialize::<K, V>) +} + +/// Iterate key-value entries in the map from the beginning. 
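// [Editorial sketch, not part of the patch] The typed variant above routes
// each raw entry through keyval::result_deserialize; types are assumptions
// chosen only for illustration:
//
//	use futures::StreamExt;
//	let mut all = map.stream::<&str, u64>();
//	while let Some(Ok((name, count))) = all.next().await { /* ... */ }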
+/// +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream(&self) -> impl Stream<Item = Result<KeyVal<'_>>> + Send { + let opts = super::read_options_default(); + stream::Items::new(&self.db, &self.cf, opts, None) +} diff --git a/src/database/map/stream_from.rs b/src/database/map/stream_from.rs new file mode 100644 index 0000000000000000000000000000000000000000..0d3bb1e10c8497e3607b04a6c3f2b8702d7eb82d --- /dev/null +++ b/src/database/map/stream_from.rs @@ -0,0 +1,67 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser, stream}; + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream<Item = Result<KeyVal<'_, K, V>>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.stream_from_raw(from) + .map(keyval::result_deserialize::<K, V>) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_from_raw<P>(&self, from: &P) -> impl Stream<Item = Result<KeyVal<'_>>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(from).expect("failed to serialize query key"); + self.raw_stream_from(&key) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_raw_from<'a, K, V, P>(&'a self, from: &P) -> impl Stream<Item = Result<KeyVal<'_, K, V>>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug + Sync, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.raw_stream_from(from) + .map(keyval::result_deserialize::<K, V>) +} + +/// Iterate key-value entries in the map starting from lower-bound. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream_from<P>(&self, from: &P) -> impl Stream<Item = Result<KeyVal<'_>>> + Send +where + P: AsRef<[u8]> + ?Sized + Debug, +{ + let opts = super::read_options_default(); + stream::Items::new(&self.db, &self.cf, opts, Some(from.as_ref())) +} diff --git a/src/database/map/stream_prefix.rs b/src/database/map/stream_prefix.rs new file mode 100644 index 0000000000000000000000000000000000000000..cab3dd0989af89e32e8438a6c54e4cf43b705a49 --- /dev/null +++ b/src/database/map/stream_prefix.rs @@ -0,0 +1,74 @@ +use std::{convert::AsRef, fmt::Debug}; + +use conduit::{implement, Result}; +use futures::{ + future, + stream::{Stream, StreamExt}, + TryStreamExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::{keyval, keyval::KeyVal, ser}; + +/// Iterate key-value entries in the map where the key matches a prefix. 
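// [Editorial sketch, not part of the patch] Structured prefixes pair with the
// `Interfix` directive from ser.rs, which leaves a trailing record separator
// so only whole leading fields match; names assumed for illustration:
//
//	use futures::StreamExt;
//	let mut entries = map.stream_prefix_raw(&(user_id, Interfix));
//	while let Some(Ok((key, val))) = entries.next().await { /* ... */ }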
+/// +/// - Query is serialized +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_prefix<'a, K, V, P>(&'a self, prefix: &P) -> impl Stream<Item = Result<KeyVal<'_, K, V>>> + Send +where + P: Serialize + ?Sized + Debug, + K: Deserialize<'a> + Send, + V: Deserialize<'a> + Send, +{ + self.stream_prefix_raw(prefix) + .map(keyval::result_deserialize::<K, V>) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is serialized +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_prefix_raw<P>(&self, prefix: &P) -> impl Stream<Item = Result<KeyVal<'_>>> + Send +where + P: Serialize + ?Sized + Debug, +{ + let key = ser::serialize_to_vec(prefix).expect("failed to serialize query key"); + self.raw_stream_from(&key) + .try_take_while(move |(k, _): &KeyVal<'_>| future::ok(k.starts_with(&key))) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is deserialized +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn stream_raw_prefix<'a, K, V, P>( + &'a self, prefix: &'a P, +) -> impl Stream<Item = Result<KeyVal<'_, K, V>>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, + K: Deserialize<'a> + Send + 'a, + V: Deserialize<'a> + Send + 'a, +{ + self.raw_stream_prefix(prefix) + .map(keyval::result_deserialize::<K, V>) +} + +/// Iterate key-value entries in the map where the key matches a prefix. +/// +/// - Query is raw +/// - Result is raw +#[implement(super::Map)] +#[tracing::instrument(skip(self), fields(%self), level = "trace")] +pub fn raw_stream_prefix<'a, P>(&'a self, prefix: &'a P) -> impl Stream<Item = Result<KeyVal<'_>>> + Send + 'a +where + P: AsRef<[u8]> + ?Sized + Debug + Sync + 'a, +{ + self.raw_stream_from(prefix) + .try_take_while(|(k, _): &KeyVal<'_>| future::ok(k.starts_with(prefix.as_ref()))) +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 6446624caa2d79d8a7843fefe5f53193b614a2d6..f09c4a71283536956ed167f3c71125ff48d60bc9 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,25 +1,36 @@ mod cork; mod database; +mod de; +mod deserialized; mod engine; mod handle; -mod iter; +pub mod keyval; mod map; pub mod maps; mod opts; -mod slice; -mod util; +mod ser; +mod stream; +mod tests; +pub(crate) mod util; mod watchers; +pub(crate) use self::{ + engine::Engine, + util::{or_else, result}, +}; + extern crate conduit_core as conduit; extern crate rust_rocksdb as rocksdb; -pub use database::Database; -pub(crate) use engine::Engine; -pub use handle::Handle; -pub use iter::Iter; -pub use map::Map; -pub use slice::{Key, KeyVal, OwnedKey, OwnedKeyVal, OwnedVal, Val}; -pub(crate) use util::{or_else, result}; +pub use self::{ + database::Database, + de::{Ignore, IgnoreAll}, + deserialized::Deserialized, + handle::Handle, + keyval::{KeyVal, Slice}, + map::Map, + ser::{serialize, serialize_to_array, serialize_to_vec, Interfix, Json, Separator, SEP}, +}; conduit::mod_ctor! {} conduit::mod_dtor! 
{} diff --git a/src/database/opts.rs b/src/database/opts.rs index d2ad4b95c7ea63e9334b796714d917edcf2b008b..46fb4c5424ff90e601735893c242234e2455c924 100644 --- a/src/database/opts.rs +++ b/src/database/opts.rs @@ -191,6 +191,8 @@ fn set_logging_defaults(opts: &mut Options, config: &Config) { if config.rocksdb_log_stderr { opts.set_stderr_logger(rocksdb_log_level, "rocksdb"); + } else { + opts.set_callback_logger(rocksdb_log_level, &super::engine::handle_log); } } diff --git a/src/database/ser.rs b/src/database/ser.rs new file mode 100644 index 0000000000000000000000000000000000000000..961d2700b9c8e1d8ca9a558aff6c3dc421ca2cc3 --- /dev/null +++ b/src/database/ser.rs @@ -0,0 +1,342 @@ +use std::io::Write; + +use arrayvec::ArrayVec; +use conduit::{debug::type_name, err, result::DebugInspect, utils::exchange, Error, Result}; +use serde::{ser, Serialize}; + +use crate::util::unhandled; + +#[inline] +pub fn serialize_to_array<const MAX: usize, T>(val: T) -> Result<ArrayVec<u8, MAX>> +where + T: Serialize, +{ + let mut buf = ArrayVec::<u8, MAX>::new(); + serialize(&mut buf, val)?; + + Ok(buf) +} + +#[inline] +pub fn serialize_to_vec<T>(val: T) -> Result<Vec<u8>> +where + T: Serialize, +{ + let mut buf = Vec::with_capacity(64); + serialize(&mut buf, val)?; + + Ok(buf) +} + +#[inline] +pub fn serialize<'a, W, T>(out: &'a mut W, val: T) -> Result<&'a [u8]> +where + W: Write + AsRef<[u8]> + 'a, + T: Serialize, +{ + let mut serializer = Serializer { + out, + depth: 0, + sep: false, + fin: false, + }; + + val.serialize(&mut serializer) + .map_err(|error| err!(SerdeSer("{error}"))) + .debug_inspect(|()| { + debug_assert_eq!(serializer.depth, 0, "Serialization completed at non-zero recursion level"); + })?; + + Ok((*out).as_ref()) +} + +pub(crate) struct Serializer<'a, W: Write> { + out: &'a mut W, + depth: u32, + sep: bool, + fin: bool, +} + +/// Newtype for JSON serialization. +#[derive(Debug, Serialize)] +pub struct Json<T>(pub T); + +/// Directive to force separator serialization specifically for prefix keying +/// use. This is a quirk of the database schema and prefix iterations. +#[derive(Debug, Serialize)] +pub struct Interfix; + +/// Directive to force separator serialization. Separators are usually +/// serialized automatically. +#[derive(Debug, Serialize)] +pub struct Separator; + +/// Record separator; an intentionally invalid-utf8 byte. 
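// [Editorial illustration, not part of the patch] Under this scheme the tuple
// ("alice", "phone") serializes to the bytes `alice\xFFphone` (fields joined
// by SEP), and appending an `Interfix` element yields `alice\xFFphone\xFF`,
// suitable as a whole-field prefix for the prefix-scan APIs.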
+pub const SEP: u8 = b'\xFF'; + +impl<W: Write> Serializer<'_, W> { + const SEP: &'static [u8] = &[SEP]; + + fn tuple_start(&mut self) { + debug_assert!(!self.sep, "Tuple start with separator set"); + self.sequence_start(); + } + + fn tuple_end(&mut self) -> Result { + self.sequence_end()?; + Ok(()) + } + + fn sequence_start(&mut self) { + debug_assert!(!self.is_finalized(), "Sequence start with finalization set"); + cfg!(debug_assertions).then(|| self.depth = self.depth.saturating_add(1)); + } + + fn sequence_end(&mut self) -> Result { + cfg!(debug_assertions).then(|| self.depth = self.depth.saturating_sub(1)); + Ok(()) + } + + fn record_start(&mut self) -> Result { + debug_assert!(!self.is_finalized(), "Starting a record after serialization finalized"); + exchange(&mut self.sep, true) + .then(|| self.separator()) + .unwrap_or(Ok(())) + } + + fn separator(&mut self) -> Result { + debug_assert!(!self.is_finalized(), "Writing a separator after serialization finalized"); + self.out.write_all(Self::SEP).map_err(Into::into) + } + + fn write(&mut self, buf: &[u8]) -> Result { self.out.write_all(buf).map_err(Into::into) } + + fn set_finalized(&mut self) { + debug_assert!(!self.is_finalized(), "Finalization already set"); + cfg!(debug_assertions).then(|| self.fin = true); + } + + fn is_finalized(&self) -> bool { self.fin } +} + +impl<W: Write> ser::Serializer for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + type SerializeMap = Self; + type SerializeSeq = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + + fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq> { + self.sequence_start(); + Ok(self) + } + + fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> { + self.tuple_start(); + Ok(self) + } + + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeTupleStruct> { + self.tuple_start(); + Ok(self) + } + + fn serialize_tuple_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, + ) -> Result<Self::SerializeTupleVariant> { + unhandled!("serialize Tuple Variant not implemented") + } + + fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> { + unhandled!("serialize Map not implemented; did you mean to use database::Json() around your serde_json::Value?") + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> { + unhandled!( + "serialize Struct not implemented at this time; did you mean to use database::Json() around your struct?" 
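// [Editorial sketch, not part of the patch] The Json escape hatch referenced
// in this message, with names assumed for illustration:
//
//	map.raw_put(event_id.as_bytes(), Json(&event)); // value stored as JSON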
+ ) + } + + fn serialize_struct_variant( + self, _name: &'static str, _idx: u32, _var: &'static str, _len: usize, + ) -> Result<Self::SerializeStructVariant> { + unhandled!("serialize Struct Variant not implemented") + } + + #[allow(clippy::needless_borrows_for_generic_args)] // buggy + fn serialize_newtype_struct<T>(self, name: &'static str, value: &T) -> Result<Self::Ok> + where + T: Serialize + ?Sized, + { + debug_assert!( + name != "Json" || type_name::<T>() != "alloc::boxed::Box<serde_json::raw::RawValue>", + "serializing a Json(RawValue); you can skip serialization instead" + ); + + match name { + "Json" => serde_json::to_writer(&mut self.out, value).map_err(Into::into), + _ => unhandled!("Unrecognized serialization Newtype {name:?}"), + } + } + + fn serialize_newtype_variant<T: Serialize + ?Sized>( + self, _name: &'static str, _idx: u32, _var: &'static str, _value: &T, + ) -> Result<Self::Ok> { + unhandled!("serialize Newtype Variant not implemented") + } + + fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok> { + match name { + "Interfix" => { + self.set_finalized(); + }, + "Separator" => { + self.separator()?; + }, + _ => unhandled!("Unrecognized serialization directive: {name:?}"), + }; + + Ok(()) + } + + fn serialize_unit_variant(self, _name: &'static str, _idx: u32, _var: &'static str) -> Result<Self::Ok> { + unhandled!("serialize Unit Variant not implemented") + } + + fn serialize_some<T: Serialize + ?Sized>(self, val: &T) -> Result<Self::Ok> { val.serialize(self) } + + fn serialize_none(self) -> Result<Self::Ok> { Ok(()) } + + fn serialize_char(self, v: char) -> Result<Self::Ok> { + let mut buf: [u8; 4] = [0; 4]; + self.serialize_str(v.encode_utf8(&mut buf)) + } + + fn serialize_str(self, v: &str) -> Result<Self::Ok> { + debug_assert!( + self.depth > 0, + "serializing string at the top-level; you can skip serialization instead" + ); + + self.serialize_bytes(v.as_bytes()) + } + + fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok> { + debug_assert!( + self.depth > 0, + "serializing byte array at the top-level; you can skip serialization instead" + ); + + self.write(v) + } + + fn serialize_f64(self, _v: f64) -> Result<Self::Ok> { unhandled!("serialize f64 not implemented") } + + fn serialize_f32(self, _v: f32) -> Result<Self::Ok> { unhandled!("serialize f32 not implemented") } + + fn serialize_i64(self, v: i64) -> Result<Self::Ok> { self.write(&v.to_be_bytes()) } + + fn serialize_i32(self, v: i32) -> Result<Self::Ok> { self.write(&v.to_be_bytes()) } + + fn serialize_i16(self, _v: i16) -> Result<Self::Ok> { unhandled!("serialize i16 not implemented") } + + fn serialize_i8(self, _v: i8) -> Result<Self::Ok> { unhandled!("serialize i8 not implemented") } + + fn serialize_u64(self, v: u64) -> Result<Self::Ok> { self.write(&v.to_be_bytes()) } + + fn serialize_u32(self, v: u32) -> Result<Self::Ok> { self.write(&v.to_be_bytes()) } + + fn serialize_u16(self, _v: u16) -> Result<Self::Ok> { unhandled!("serialize u16 not implemented") } + + fn serialize_u8(self, v: u8) -> Result<Self::Ok> { self.write(&[v]) } + + fn serialize_bool(self, _v: bool) -> Result<Self::Ok> { unhandled!("serialize bool not implemented") } + + fn serialize_unit(self) -> Result<Self::Ok> { unhandled!("serialize unit not implemented") } +} + +impl<W: Write> ser::SerializeSeq for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element<T: Serialize + ?Sized>(&mut self, val: &T) -> Result<Self::Ok> { val.serialize(&mut **self) } + + fn end(self) -> Result<Self::Ok> { 
self.sequence_end() } +} + +impl<W: Write> ser::SerializeTuple for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_element<T: Serialize + ?Sized>(&mut self, val: &T) -> Result<Self::Ok> { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result<Self::Ok> { self.tuple_end() } +} + +impl<W: Write> ser::SerializeTupleStruct for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field<T: Serialize + ?Sized>(&mut self, val: &T) -> Result<Self::Ok> { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result<Self::Ok> { self.tuple_end() } +} + +impl<W: Write> ser::SerializeTupleVariant for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field<T: Serialize + ?Sized>(&mut self, val: &T) -> Result<Self::Ok> { + self.record_start()?; + val.serialize(&mut **self) + } + + fn end(self) -> Result<Self::Ok> { self.tuple_end() } +} + +impl<W: Write> ser::SerializeMap for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_key<T: Serialize + ?Sized>(&mut self, _key: &T) -> Result<Self::Ok> { + unhandled!("serialize Map Key not implemented") + } + + fn serialize_value<T: Serialize + ?Sized>(&mut self, _val: &T) -> Result<Self::Ok> { + unhandled!("serialize Map Val not implemented") + } + + fn end(self) -> Result<Self::Ok> { unhandled!("serialize Map End not implemented") } +} + +impl<W: Write> ser::SerializeStruct for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field<T: Serialize + ?Sized>(&mut self, _key: &'static str, _val: &T) -> Result<Self::Ok> { + unhandled!("serialize Struct Field not implemented") + } + + fn end(self) -> Result<Self::Ok> { unhandled!("serialize Struct End not implemented") } +} + +impl<W: Write> ser::SerializeStructVariant for &mut Serializer<'_, W> { + type Error = Error; + type Ok = (); + + fn serialize_field<T: Serialize + ?Sized>(&mut self, _key: &'static str, _val: &T) -> Result<Self::Ok> { + unhandled!("serialize Struct Variant Field not implemented") + } + + fn end(self) -> Result<Self::Ok> { unhandled!("serialize Struct Variant End not implemented") } +} diff --git a/src/database/slice.rs b/src/database/slice.rs deleted file mode 100644 index 448d969d99c5a6f34753fad988fb68e217390598..0000000000000000000000000000000000000000 --- a/src/database/slice.rs +++ /dev/null @@ -1,57 +0,0 @@ -pub struct OwnedKeyVal(pub OwnedKey, pub OwnedVal); -pub(crate) type OwnedKeyValPair = (OwnedKey, OwnedVal); -pub type OwnedVal = Vec<Byte>; -pub type OwnedKey = Vec<Byte>; - -pub struct KeyVal<'item>(pub &'item Key, pub &'item Val); -pub(crate) type KeyValPair<'item> = (&'item Key, &'item Val); -pub type Val = [Byte]; -pub type Key = [Byte]; - -pub(crate) type Byte = u8; - -impl OwnedKeyVal { - #[must_use] - pub fn as_slice(&self) -> KeyVal<'_> { KeyVal(&self.0, &self.1) } - - #[must_use] - pub fn to_tuple(self) -> OwnedKeyValPair { (self.0, self.1) } -} - -impl From<OwnedKeyValPair> for OwnedKeyVal { - fn from((key, val): OwnedKeyValPair) -> Self { Self(key, val) } -} - -impl From<&KeyVal<'_>> for OwnedKeyVal { - #[inline] - fn from(slice: &KeyVal<'_>) -> Self { slice.to_owned() } -} - -impl From<KeyValPair<'_>> for OwnedKeyVal { - fn from((key, val): KeyValPair<'_>) -> Self { Self(Vec::from(key), Vec::from(val)) } -} - -impl From<OwnedKeyVal> for OwnedKeyValPair { - fn from(val: OwnedKeyVal) -> Self { val.to_tuple() } -} - -impl KeyVal<'_> { - #[inline] - #[must_use] - pub fn 
to_owned(&self) -> OwnedKeyVal { OwnedKeyVal::from(self) }
-
-	#[must_use]
-	pub fn as_tuple(&self) -> KeyValPair<'_> { (self.0, self.1) }
-}
-
-impl<'a> From<&'a OwnedKeyVal> for KeyVal<'a> {
-	fn from(owned: &'a OwnedKeyVal) -> Self { owned.as_slice() }
-}
-
-impl<'a> From<&'a OwnedKeyValPair> for KeyVal<'a> {
-	fn from((key, val): &'a OwnedKeyValPair) -> Self { KeyVal(key.as_slice(), val.as_slice()) }
-}
-
-impl<'a> From<KeyValPair<'a>> for KeyVal<'a> {
-	fn from((key, val): KeyValPair<'a>) -> Self { KeyVal(key, val) }
-}
diff --git a/src/database/stream.rs b/src/database/stream.rs
new file mode 100644
index 0000000000000000000000000000000000000000..a2a72e44c50fda55c50db974a862e5065e2da5d3
--- /dev/null
+++ b/src/database/stream.rs
@@ -0,0 +1,125 @@
+mod items;
+mod items_rev;
+mod keys;
+mod keys_rev;
+
+use std::sync::Arc;
+
+use conduit::{utils::exchange, Error, Result};
+use rocksdb::{ColumnFamily, DBRawIteratorWithThreadMode, ReadOptions};
+
+pub(crate) use self::{items::Items, items_rev::ItemsRev, keys::Keys, keys_rev::KeysRev};
+use crate::{
+	engine::Db,
+	keyval::{Key, KeyVal, Val},
+	util::map_err,
+	Engine, Slice,
+};
+
+struct State<'a> {
+	inner: Inner<'a>,
+	seek: bool,
+	init: bool,
+}
+
+trait Cursor<'a, T> {
+	fn state(&self) -> &State<'a>;
+
+	fn fetch(&self) -> Option<T>;
+
+	fn seek(&mut self);
+
+	fn get(&self) -> Option<Result<T>> {
+		self.fetch()
+			.map(Ok)
+			.or_else(|| self.state().status().map(Err))
+	}
+
+	fn seek_and_get(&mut self) -> Option<Result<T>> {
+		self.seek();
+		self.get()
+	}
+}
+
+type Inner<'a> = DBRawIteratorWithThreadMode<'a, Db>;
+type From<'a> = Option<Key<'a>>;
+
+impl<'a> State<'a> {
+	fn new(db: &'a Arc<Engine>, cf: &'a Arc<ColumnFamily>, opts: ReadOptions) -> Self {
+		Self {
+			inner: db.db.raw_iterator_cf_opt(&**cf, opts),
+			init: true,
+			seek: false,
+		}
+	}
+
+	fn init_fwd(mut self, from: From<'_>) -> Self {
+		if let Some(key) = from {
+			self.inner.seek(key);
+			self.seek = true;
+		}
+
+		self
+	}
+
+	fn init_rev(mut self, from: From<'_>) -> Self {
+		if let Some(key) = from {
+			self.inner.seek_for_prev(key);
+			self.seek = true;
+		}
+
+		self
+	}
+
+	#[inline]
+	fn seek_fwd(&mut self) {
+		if !exchange(&mut self.init, false) {
+			self.inner.next();
+		} else if !self.seek {
+			self.inner.seek_to_first();
+		}
+	}
+
+	#[inline]
+	fn seek_rev(&mut self) {
+		if !exchange(&mut self.init, false) {
+			self.inner.prev();
+		} else if !self.seek {
+			self.inner.seek_to_last();
+		}
+	}
+
+	fn fetch_key(&self) -> Option<Key<'_>> { self.inner.key().map(Key::from) }
+
+	fn _fetch_val(&self) -> Option<Val<'_>> { self.inner.value().map(Val::from) }
+
+	fn fetch(&self) -> Option<KeyVal<'_>> { self.inner.item().map(KeyVal::from) }
+
+	fn status(&self) -> Option<Error> { self.inner.status().map_err(map_err).err() }
+
+	#[inline]
+	fn valid(&self) -> bool { self.inner.valid() }
+}
+
+fn keyval_longevity<'a, 'b: 'a>(item: KeyVal<'a>) -> KeyVal<'b> {
+	(slice_longevity::<'a, 'b>(item.0), slice_longevity::<'a, 'b>(item.1))
+}
+
+fn slice_longevity<'a, 'b: 'a>(item: &'a Slice) -> &'b Slice {
+	// SAFETY: The lifetime of the data returned by the rocksdb cursor is only valid
+	// between each movement of the cursor. It is hereby unsafely extended to match
+	// the lifetime of the cursor itself. This is due to a limitation of the Stream
+	// trait: its Item cannot convey a lifetime, because GATs were still unstable
+	// during that trait's development. This unsafety can be removed as soon as the
+	// limitation is addressed in an upcoming version.
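+	//
+	// Illustrative sketch of the sound usage pattern: items borrowed from the
+	// cursor are only valid until its next movement, so a consumer must copy
+	// them out in time, e.g. `stream.map(ToOwned::to_owned).collect()`, rather
+	// than collecting the borrowed slices themselves.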
+ // + // We have done our best to mitigate the implications of this in conjunction + // with the deserialization API such that borrows being held across movements of + // the cursor do not happen accidentally. The compiler will still error when + // values herein produced try to leave a closure passed to a StreamExt API. But + // escapes can happen if you explicitly and intentionally attempt it, and there + // will be no compiler error or warning. This is primarily the case with + // calling collect() without a preceding map(ToOwned::to_owned). A collection + // of references here is illegal, but this will not be enforced by the compiler. + unsafe { std::mem::transmute(item) } +} diff --git a/src/database/stream/items.rs b/src/database/stream/items.rs new file mode 100644 index 0000000000000000000000000000000000000000..54f8bc5c971fedb2264d361f540eb96973c9315e --- /dev/null +++ b/src/database/stream/items.rs @@ -0,0 +1,46 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{keyval_longevity, Cursor, From, State}; +use crate::{keyval::KeyVal, Engine}; + +pub(crate) struct Items<'a> { + state: State<'a>, +} + +impl<'a> Items<'a> { + pub(crate) fn new(db: &'a Arc<Engine>, cf: &'a Arc<ColumnFamily>, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_fwd(from), + } + } +} + +impl<'a> Cursor<'a, KeyVal<'a>> for Items<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option<KeyVal<'a>> { self.state.fetch().map(keyval_longevity) } + + #[inline] + fn seek(&mut self) { self.state.seek_fwd(); } +} + +impl<'a> Stream for Items<'a> { + type Item = Result<KeyVal<'a>>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for Items<'_> { + #[inline] + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/items_rev.rs b/src/database/stream/items_rev.rs new file mode 100644 index 0000000000000000000000000000000000000000..26492db8ccb9331b566f8777f0e8988c585daa45 --- /dev/null +++ b/src/database/stream/items_rev.rs @@ -0,0 +1,46 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{keyval_longevity, Cursor, From, State}; +use crate::{keyval::KeyVal, Engine}; + +pub(crate) struct ItemsRev<'a> { + state: State<'a>, +} + +impl<'a> ItemsRev<'a> { + pub(crate) fn new(db: &'a Arc<Engine>, cf: &'a Arc<ColumnFamily>, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_rev(from), + } + } +} + +impl<'a> Cursor<'a, KeyVal<'a>> for ItemsRev<'a> { + fn state(&self) -> &State<'a> { &self.state } + + fn fetch(&self) -> Option<KeyVal<'a>> { self.state.fetch().map(keyval_longevity) } + + #[inline] + fn seek(&mut self) { self.state.seek_rev(); } +} + +impl<'a> Stream for ItemsRev<'a> { + type Item = Result<KeyVal<'a>>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for ItemsRev<'_> { + #[inline] + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/keys.rs b/src/database/stream/keys.rs new file mode 100644 
index 0000000000000000000000000000000000000000..91884c8dce748335309fb8baf7791748a4e7d0ba --- /dev/null +++ b/src/database/stream/keys.rs @@ -0,0 +1,47 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{slice_longevity, Cursor, From, State}; +use crate::{keyval::Key, Engine}; + +pub(crate) struct Keys<'a> { + state: State<'a>, +} + +impl<'a> Keys<'a> { + pub(crate) fn new(db: &'a Arc<Engine>, cf: &'a Arc<ColumnFamily>, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_fwd(from), + } + } +} + +impl<'a> Cursor<'a, Key<'a>> for Keys<'a> { + fn state(&self) -> &State<'a> { &self.state } + + #[inline] + fn fetch(&self) -> Option<Key<'a>> { self.state.fetch_key().map(slice_longevity) } + + #[inline] + fn seek(&mut self) { self.state.seek_fwd(); } +} + +impl<'a> Stream for Keys<'a> { + type Item = Result<Key<'a>>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for Keys<'_> { + #[inline] + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/stream/keys_rev.rs b/src/database/stream/keys_rev.rs new file mode 100644 index 0000000000000000000000000000000000000000..59f66c2e59b4d87d58adf5c1059cfd37fdbabe8a --- /dev/null +++ b/src/database/stream/keys_rev.rs @@ -0,0 +1,47 @@ +use std::{pin::Pin, sync::Arc}; + +use conduit::Result; +use futures::{ + stream::FusedStream, + task::{Context, Poll}, + Stream, +}; +use rocksdb::{ColumnFamily, ReadOptions}; + +use super::{slice_longevity, Cursor, From, State}; +use crate::{keyval::Key, Engine}; + +pub(crate) struct KeysRev<'a> { + state: State<'a>, +} + +impl<'a> KeysRev<'a> { + pub(crate) fn new(db: &'a Arc<Engine>, cf: &'a Arc<ColumnFamily>, opts: ReadOptions, from: From<'_>) -> Self { + Self { + state: State::new(db, cf, opts).init_rev(from), + } + } +} + +impl<'a> Cursor<'a, Key<'a>> for KeysRev<'a> { + fn state(&self) -> &State<'a> { &self.state } + + #[inline] + fn fetch(&self) -> Option<Key<'a>> { self.state.fetch_key().map(slice_longevity) } + + #[inline] + fn seek(&mut self) { self.state.seek_rev(); } +} + +impl<'a> Stream for KeysRev<'a> { + type Item = Result<Key<'a>>; + + fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + Poll::Ready(self.seek_and_get()) + } +} + +impl FusedStream for KeysRev<'_> { + #[inline] + fn is_terminated(&self) -> bool { !self.state.init && !self.state.valid() } +} diff --git a/src/database/tests.rs b/src/database/tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..bfab99ef0d503ad76132c797b575c04c41d6c785 --- /dev/null +++ b/src/database/tests.rs @@ -0,0 +1,292 @@ +#![cfg(test)] +#![allow(clippy::needless_borrows_for_generic_args)] + +use std::fmt::Debug; + +use arrayvec::ArrayVec; +use conduit::ruma::{serde::Raw, RoomId, UserId}; +use serde::Serialize; + +use crate::{ + de, ser, + ser::{serialize_to_vec, Json}, + Ignore, Interfix, +}; + +#[test] +#[should_panic(expected = "serializing string at the top-level")] +fn ser_str() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let s = serialize_to_vec(&user_id).expect("failed to serialize user_id"); + assert_eq!(&s, user_id.as_bytes()); +} + +#[test] +fn ser_tuple() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let 
room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let mut a = user_id.as_bytes().to_vec(); + a.push(0xFF); + a.extend_from_slice(room_id.as_bytes()); + + let b = (user_id, room_id); + let b = serialize_to_vec(&b).expect("failed to serialize tuple"); + + assert_eq!(a, b); +} + +#[test] +#[should_panic(expected = "I/O error: failed to write whole buffer")] +fn ser_overflow() { + const BUFSIZE: usize = 10; + + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + assert!(BUFSIZE < user_id.as_str().len() + room_id.as_str().len()); + let mut buf = ArrayVec::<u8, BUFSIZE>::new(); + + let val = (user_id, room_id); + _ = ser::serialize(&mut buf, val).unwrap(); +} + +#[test] +fn ser_complex() { + use conduit::ruma::Mxc; + + #[derive(Debug, Serialize)] + struct Dim { + width: u32, + height: u32, + } + + let mxc = Mxc { + server_name: "example.com".try_into().unwrap(), + media_id: "AbCdEfGhIjK", + }; + + let dim = Dim { + width: 123, + height: 456, + }; + + let mut a = Vec::new(); + a.extend_from_slice(b"mxc://"); + a.extend_from_slice(mxc.server_name.as_bytes()); + a.extend_from_slice(b"/"); + a.extend_from_slice(mxc.media_id.as_bytes()); + a.push(0xFF); + a.extend_from_slice(&dim.width.to_be_bytes()); + a.extend_from_slice(&dim.height.to_be_bytes()); + a.push(0xFF); + + let d: &[u32] = &[dim.width, dim.height]; + let b = (mxc, d, Interfix); + let b = serialize_to_vec(b).expect("failed to serialize complex"); + + assert_eq!(a, b); +} + +#[test] +fn ser_json() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let serialized = serialize_to_vec(Json(&filter)).expect("failed to serialize value"); + + let s = String::from_utf8_lossy(&serialized); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +fn ser_json_value() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let value = serde_json::to_value(filter).expect("failed to serialize to serde_json::value"); + let serialized = serialize_to_vec(Json(value)).expect("failed to serialize value"); + + let s = String::from_utf8_lossy(&serialized); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +fn ser_json_macro() { + use serde_json::json; + + #[derive(Serialize)] + struct Foo { + foo: String, + } + + let content = Foo { + foo: "bar".to_owned(), + }; + let content = serde_json::to_value(content).expect("failed to serialize content"); + let sender: &UserId = "@foo:example.com".try_into().unwrap(); + let serialized = serialize_to_vec(Json(json!({ + "sender": sender, + "content": content, + }))) + .expect("failed to serialize value"); + + let s = String::from_utf8_lossy(&serialized); + assert_eq!(&s, r#"{"content":{"foo":"bar"},"sender":"@foo:example.com"}"#); +} + +#[test] +#[should_panic(expected = "serializing string at the top-level")] +fn ser_json_raw() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let value = serde_json::value::to_raw_value(&filter).expect("failed to serialize to raw value"); + let a = serialize_to_vec(value.get()).expect("failed to serialize raw value"); + let s = String::from_utf8_lossy(&a); + 
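+	// Unreachable when debug assertions are enabled: handing the raw value's
+	// `&str` to `serialize_to_vec` trips the serializer's top-level string
+	// debug_assert, which is the panic `should_panic` expects.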
assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +#[should_panic(expected = "you can skip serialization instead")] +fn ser_json_raw_json() { + use conduit::ruma::api::client::filter::FilterDefinition; + + let filter = FilterDefinition { + event_fields: Some(vec!["content.body".to_owned()]), + ..Default::default() + }; + + let value = serde_json::value::to_raw_value(&filter).expect("failed to serialize to raw value"); + let a = serialize_to_vec(Json(value)).expect("failed to serialize json value"); + let s = String::from_utf8_lossy(&a); + assert_eq!(&s, r#"{"event_fields":["content.body"]}"#); +} + +#[test] +fn de_tuple() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com\xFF!room:example.com"; + let (a, b): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(b, room_id, "deserialized room_id does not match"); +} + +#[test] +#[should_panic(expected = "failed to deserialize")] +fn de_tuple_invalid() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com\xFF@user:example.com"; + let (a, b): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(b, room_id, "deserialized room_id does not match"); +} + +#[test] +#[should_panic(expected = "failed to deserialize")] +fn de_tuple_incomplete() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com"; + let (a, _): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); +} + +#[test] +#[should_panic(expected = "failed to deserialize")] +fn de_tuple_incomplete_with_sep() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com\xFF"; + let (a, _): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); +} + +#[test] +#[should_panic(expected = "deserialization failed to consume trailing bytes")] +fn de_tuple_unfinished() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com\xFF!room:example.com\xFF@user:example.com"; + let (a, b): (&UserId, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(b, room_id, "deserialized room_id does not match"); +} + +#[test] +fn de_tuple_ignore() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com\xFF@user2:example.net\xFF!room:example.com"; + let (a, _, c): (&UserId, Ignore, &RoomId) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(c, room_id, "deserialized room_id does not match"); +} + +#[test] +fn de_json_array() { + let a = &["foo", "bar", "baz"]; + let s = serde_json::to_vec(a).expect("failed to serialize to JSON array"); + + let b: Raw<Vec<Raw<String>>> = 
de::from_slice(&s).expect("failed to deserialize"); + + let d: Vec<String> = serde_json::from_str(b.json().get()).expect("failed to deserialize JSON"); + + for (i, a) in a.iter().enumerate() { + assert_eq!(*a, d[i]); + } +} + +#[test] +fn de_json_raw_array() { + let a = &["foo", "bar", "baz"]; + let s = serde_json::to_vec(a).expect("failed to serialize to JSON array"); + + let b: Raw<Vec<Raw<String>>> = de::from_slice(&s).expect("failed to deserialize"); + + let c: Vec<Raw<String>> = serde_json::from_str(b.json().get()).expect("failed to deserialize JSON"); + + for (i, a) in a.iter().enumerate() { + let c = serde_json::to_value(c[i].json()).expect("failed to deserialize JSON to string"); + assert_eq!(*a, c); + } +} + +#[test] +fn ser_array() { + let a: u64 = 123_456; + let b: u64 = 987_654; + + let arr: &[u64] = &[a, b]; + + let mut v = Vec::new(); + v.extend_from_slice(&a.to_be_bytes()); + v.extend_from_slice(&b.to_be_bytes()); + + let s = serialize_to_vec(arr).expect("failed to serialize"); + assert_eq!(&s, &v, "serialization does not match"); +} diff --git a/src/database/util.rs b/src/database/util.rs index f0ccbcbee045f793cddfa37ab92333ca9ded3ddd..ae0763812dd6e3d710754c1fecdac3669f05bc5d 100644 --- a/src/database/util.rs +++ b/src/database/util.rs @@ -1,4 +1,39 @@ use conduit::{err, Result}; +use rocksdb::{Direction, IteratorMode}; + +//#[cfg(debug_assertions)] +macro_rules! unhandled { + ($msg:literal) => { + unimplemented!($msg) + }; +} + +// activate when stable; we're not ready for this yet +#[cfg(disable)] // #[cfg(not(debug_assertions))] +macro_rules! unhandled { + ($msg:literal) => { + // SAFETY: Eliminates branches for serializing and deserializing types never + // encountered in the codebase. This can promote optimization and reduce + // codegen. The developer must verify for every invoking callsite that the + // unhandled type is in no way involved and could not possibly be encountered. + unsafe { + std::hint::unreachable_unchecked(); + } + }; +} + +pub(crate) use unhandled; + +#[inline] +pub(crate) fn _into_direction(mode: &IteratorMode<'_>) -> Direction { + use Direction::{Forward, Reverse}; + use IteratorMode::{End, From, Start}; + + match mode { + Start | From(_, Forward) => Forward, + End | From(_, Reverse) => Reverse, + } +} #[inline] pub(crate) fn result<T>(r: std::result::Result<T, rocksdb::Error>) -> Result<T, conduit::Error> { diff --git a/src/macros/config.rs b/src/macros/config.rs new file mode 100644 index 0000000000000000000000000000000000000000..d7f115359382174b04dd585c99f06a456d32e018 --- /dev/null +++ b/src/macros/config.rs @@ -0,0 +1,304 @@ +use std::{ + collections::{HashMap, HashSet}, + fmt::Write as _, + fs::OpenOptions, + io::Write as _, +}; + +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::ToTokens; +use syn::{ + parse::Parser, punctuated::Punctuated, spanned::Spanned, Error, Expr, ExprLit, Field, Fields, FieldsNamed, + ItemStruct, Lit, Meta, MetaList, MetaNameValue, Type, TypePath, +}; + +use crate::{utils::is_cargo_build, Result}; + +const UNDOCUMENTED: &str = "# This item is undocumented. 
Please contribute documentation for it."; + +#[allow(clippy::needless_pass_by_value)] +pub(super) fn example_generator(input: ItemStruct, args: &[Meta]) -> Result<TokenStream> { + if is_cargo_build() { + generate_example(&input, args)?; + } + + Ok(input.to_token_stream().into()) +} + +#[allow(clippy::needless_pass_by_value)] +#[allow(unused_variables)] +fn generate_example(input: &ItemStruct, args: &[Meta]) -> Result<()> { + let settings = get_settings(args); + + let filename = settings + .get("filename") + .ok_or_else(|| Error::new(args[0].span(), "missing required 'filename' attribute argument"))?; + + let undocumented = settings + .get("undocumented") + .map_or(UNDOCUMENTED, String::as_str); + + let ignore: HashSet<&str> = settings + .get("ignore") + .map_or("", String::as_str) + .split(' ') + .collect(); + + let section = settings + .get("section") + .ok_or_else(|| Error::new(args[0].span(), "missing required 'section' attribute argument"))?; + + let mut file = OpenOptions::new() + .write(true) + .create(section == "global") + .truncate(section == "global") + .append(section != "global") + .open(filename) + .map_err(|e| Error::new(Span::call_site(), format!("Failed to open config file for generation: {e}")))?; + + if let Some(header) = settings.get("header") { + file.write_all(header.as_bytes()) + .expect("written to config file"); + } + + file.write_fmt(format_args!("\n[{section}]\n")) + .expect("written to config file"); + + if let Fields::Named(FieldsNamed { + named, + .. + }) = &input.fields + { + for field in named { + let Some(ident) = &field.ident else { + continue; + }; + + if ignore.contains(ident.to_string().as_str()) { + continue; + } + + let Some(type_name) = get_type_name(field) else { + continue; + }; + + let doc = get_doc_comment(field) + .unwrap_or_else(|| undocumented.into()) + .trim_end() + .to_owned(); + + let doc = if doc.ends_with('#') { + format!("{doc}\n") + } else { + format!("{doc}\n#\n") + }; + + let default = get_doc_default(field) + .or_else(|| get_default(field)) + .unwrap_or_default(); + + let default = if !default.is_empty() { + format!(" {default}") + } else { + default + }; + + file.write_fmt(format_args!("\n{doc}")) + .expect("written to config file"); + + file.write_fmt(format_args!("#{ident} ={default}\n")) + .expect("written to config file"); + } + } + + if let Some(footer) = settings.get("footer") { + file.write_all(footer.as_bytes()) + .expect("written to config file"); + } + + Ok(()) +} + +fn get_settings(args: &[Meta]) -> HashMap<String, String> { + let mut map = HashMap::new(); + for arg in args { + let Meta::NameValue(MetaNameValue { + path, + value, + .. + }) = arg + else { + continue; + }; + + let Expr::Lit( + ExprLit { + lit: Lit::Str(str), + .. + }, + .., + ) = value + else { + continue; + }; + + let Some(key) = path.segments.iter().next().map(|s| s.ident.clone()) else { + continue; + }; + + map.insert(key.to_string(), str.value()); + } + + map +} + +fn get_default(field: &Field) -> Option<String> { + for attr in &field.attrs { + let Meta::List(MetaList { + path, + tokens, + .. + }) = &attr.meta + else { + continue; + }; + + if path + .segments + .iter() + .next() + .is_none_or(|s| s.ident != "serde") + { + continue; + } + + let Some(arg) = Punctuated::<Meta, syn::Token![,]>::parse_terminated + .parse(tokens.clone().into()) + .ok()? + .iter() + .next() + .cloned() + else { + continue; + }; + + match arg { + Meta::NameValue(MetaNameValue { + value: Expr::Lit(ExprLit { + lit: Lit::Str(str), + .. + }), + .. 
+			}) => {
+				match str.value().as_str() {
+					"HashSet::new" | "Vec::new" | "RegexSet::empty" => return Some("[]".to_owned()),
+					"true_fn" => return Some("true".to_owned()),
+					_ => return None,
+				};
+			},
+			Meta::Path {
+				..
+			} => return Some("false".to_owned()),
+			_ => return None,
+		};
+	}
+
+	None
+}
+
+fn get_doc_default(field: &Field) -> Option<String> {
+	for attr in &field.attrs {
+		let Meta::NameValue(MetaNameValue {
+			path,
+			value,
+			..
+		}) = &attr.meta
+		else {
+			continue;
+		};
+
+		if path.segments.iter().next().is_none_or(|s| s.ident != "doc") {
+			continue;
+		}
+
+		let Expr::Lit(ExprLit {
+			lit,
+			..
+		}) = &value
+		else {
+			continue;
+		};
+
+		let Lit::Str(token) = &lit else {
+			continue;
+		};
+
+		let value = token.value();
+		if !value.trim().starts_with("default:") {
+			continue;
+		}
+
+		return value
+			.split_once(':')
+			.map(|(_, v)| v)
+			.map(str::trim)
+			.map(ToOwned::to_owned);
+	}
+
+	None
+}
+
+fn get_doc_comment(field: &Field) -> Option<String> {
+	let mut out = String::new();
+	for attr in &field.attrs {
+		let Meta::NameValue(MetaNameValue {
+			path,
+			value,
+			..
+		}) = &attr.meta
+		else {
+			continue;
+		};
+
+		if path.segments.iter().next().is_none_or(|s| s.ident != "doc") {
+			continue;
+		}
+
+		let Expr::Lit(ExprLit {
+			lit,
+			..
+		}) = &value
+		else {
+			continue;
+		};
+
+		let Lit::Str(token) = &lit else {
+			continue;
+		};
+
+		let value = token.value();
+		if value.trim().starts_with("default:") {
+			continue;
+		}
+
+		writeln!(&mut out, "#{value}").expect("wrote to output string buffer");
+	}
+
+	(!out.is_empty()).then_some(out)
+}
+
+fn get_type_name(field: &Field) -> Option<String> {
+	let Type::Path(TypePath {
+		path,
+		..
+	}) = &field.ty
+	else {
+		return None;
+	};
+
+	path.segments
+		.iter()
+		.next()
+		.map(|segment| segment.ident.to_string())
+}
diff --git a/src/macros/mod.rs b/src/macros/mod.rs
index d32cda71c2172f8a441cfee19ba929cc69f62a51..1aa1e24fd3c63e84141b64b563b0a56b33009110 100644
--- a/src/macros/mod.rs
+++ b/src/macros/mod.rs
@@ -1,5 +1,6 @@
 mod admin;
 mod cargo;
+mod config;
 mod debug;
 mod implement;
 mod refutable;
@@ -9,7 +10,7 @@
 use proc_macro::TokenStream;
 use syn::{
 	parse::{Parse, Parser},
-	parse_macro_input, Error, Item, ItemConst, ItemEnum, ItemFn, Meta,
+	parse_macro_input, Error, Item, ItemConst, ItemEnum, ItemFn, ItemStruct, Meta,
 };
 
 pub(crate) type Result<T> = std::result::Result<T, Error>;
@@ -47,6 +48,11 @@ pub fn implement(args: TokenStream, input: TokenStream) -> TokenStream {
 	attribute_macro::<ItemFn, _>(args, input, implement::implement)
 }
 
+#[proc_macro_attribute]
+pub fn config_example_generator(args: TokenStream, input: TokenStream) -> TokenStream {
+	attribute_macro::<ItemStruct, _>(args, input, config::example_generator)
+}
+
 fn attribute_macro<I, F>(args: TokenStream, input: TokenStream, func: F) -> TokenStream
 where
 	F: Fn(I, &[Meta]) -> Result<TokenStream>,
diff --git a/src/macros/utils.rs b/src/macros/utils.rs
index 58074e3a0ca0fcda7bb37465d6ec8fbdba9fd1d8..e4ffc622d44ed31c1c6868c3fed085240e4c1dc3 100644
--- a/src/macros/utils.rs
+++ b/src/macros/utils.rs
@@ -2,6 +2,16 @@
 
 use crate::Result;
 
+pub(crate) fn is_cargo_build() -> bool {
+	std::env::args()
+		.find(|flag| flag.starts_with("--emit"))
+		.as_ref()
+		.and_then(|flag| flag.split_once('='))
+		.map(|val| val.1.split(','))
+		.and_then(|mut vals| vals.find(|elem| *elem == "link"))
+		.is_some()
+}
+
 pub(crate) fn get_named_generics(args: &[Meta], name: &str) -> Result<Generics> {
 	const DEFAULT: &str = "<>";
 
@@ -41,8 +51,5 @@ pub(crate) fn camel_to_snake_string(s: &str) -> String {
output } -pub(crate) fn exchange<T: Clone>(state: &mut T, source: T) -> T { - let ret = state.clone(); - *state = source; - ret -} +#[inline] +pub(crate) fn exchange<T>(state: &mut T, source: T) -> T { std::mem::replace(state, source) } diff --git a/src/main/Cargo.toml b/src/main/Cargo.toml index b3390bfb1adbb61e993838fd346f8bd097967d1b..b91229425ba104206da9ffa473539e0367441bcb 100644 --- a/src/main/Cargo.toml +++ b/src/main/Cargo.toml @@ -44,7 +44,6 @@ default = [ "jemalloc", "jemalloc_stats", "release_max_log_level", - "sentry_telemetry", "systemd", "zstd_compression", ] diff --git a/src/main/main.rs b/src/main/main.rs index 8703eef2bcb72082ec71af539ad6b7682712e949..8e644a1583c5d975f1cdabc18203e2134d37dd5a 100644 --- a/src/main/main.rs +++ b/src/main/main.rs @@ -1,5 +1,3 @@ -#![recursion_limit = "192"] - pub(crate) mod clap; mod mods; mod restart; diff --git a/src/main/server.rs b/src/main/server.rs index e435b2f4410e573bb5dc388876c14fc7e61ba420..4813d586c43a063c67f5bb7aa75aa9b0e37b61a0 100644 --- a/src/main/server.rs +++ b/src/main/server.rs @@ -24,7 +24,7 @@ pub(crate) struct Server { impl Server { pub(crate) fn build(args: &Args, runtime: Option<&runtime::Handle>) -> Result<Arc<Self>, Error> { - let raw_config = Config::load(&args.config)?; + let raw_config = Config::load(args.config.as_deref())?; let raw_config = crate::clap::update(raw_config, args)?; let config = Config::new(&raw_config)?; diff --git a/src/main/tracing.rs b/src/main/tracing.rs index 9b4ad659d7dcded1da4d8c560b024f6979226508..c28fef6b860dcdbe850119e5f9eef5f792d2f685 100644 --- a/src/main/tracing.rs +++ b/src/main/tracing.rs @@ -3,7 +3,8 @@ use conduit::{ config::Config, debug_warn, err, - log::{capture, LogLevelReloadHandles}, + log::{capture, fmt_span, LogLevelReloadHandles}, + result::UnwrapOrErr, Result, }; use tracing_subscriber::{layer::SubscriberExt, reload, EnvFilter, Layer, Registry}; @@ -18,7 +19,10 @@ pub(crate) fn init(config: &Config) -> Result<(LogLevelReloadHandles, TracingFla let reload_handles = LogLevelReloadHandles::default(); let console_filter = EnvFilter::try_new(&config.log).map_err(|e| err!(Config("log", "{e}.")))?; - let console_layer = tracing_subscriber::fmt::Layer::new().with_ansi(config.log_colors); + let console_span_events = fmt_span::from_str(&config.log_span_events).unwrap_or_err(); + let console_layer = tracing_subscriber::fmt::Layer::new() + .with_ansi(config.log_colors) + .with_span_events(console_span_events); let (console_reload_filter, console_reload_handle) = reload::Layer::new(console_filter.clone()); reload_handles.add("console", Box::new(console_reload_handle)); diff --git a/src/router/Cargo.toml b/src/router/Cargo.toml index 62690194e01fa0f37ca42029e8d5f2b7e446a076..e15358687a2a6a4bb261c13d7f67692f504f7b68 100644 --- a/src/router/Cargo.toml +++ b/src/router/Cargo.toml @@ -54,20 +54,18 @@ axum-server-dual-protocol.workspace = true axum-server-dual-protocol.optional = true axum-server.workspace = true axum.workspace = true +bytes.workspace = true conduit-admin.workspace = true conduit-api.workspace = true conduit-core.workspace = true conduit-service.workspace = true const-str.workspace = true -log.workspace = true -tokio.workspace = true -tower.workspace = true -tracing.workspace = true -bytes.workspace = true -http-body-util.workspace = true +futures.workspace = true http.workspace = true +http-body-util.workspace = true hyper.workspace = true hyper-util.workspace = true +log.workspace = true ruma.workspace = true rustls.workspace = true rustls.optional = true @@ -78,7 
+76,10 @@ sentry-tracing.optional = true sentry-tracing.workspace = true sentry.workspace = true serde_json.workspace = true +tokio.workspace = true +tower.workspace = true tower-http.workspace = true +tracing.workspace = true [target.'cfg(unix)'.dependencies] sd-notify.workspace = true diff --git a/src/router/layers.rs b/src/router/layers.rs index a1a70bb86500748db48f8a52b71d318cb128d7dd..fd68cc36750c109130362373f604a9ae45ab71cc 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -24,15 +24,15 @@ use crate::{request, router}; -const CONDUWUIT_CSP: &[&str] = &[ - "sandbox", +const CONDUWUIT_CSP: &[&str; 5] = &[ "default-src 'none'", "frame-ancestors 'none'", "form-action 'none'", "base-uri 'none'", + "sandbox", ]; -const CONDUWUIT_PERMISSIONS_POLICY: &[&str] = &["interest-cohort=()", "browsing-topics=()"]; +const CONDUWUIT_PERMISSIONS_POLICY: &[&str; 2] = &["interest-cohort=()", "browsing-topics=()"]; pub(crate) fn build(services: &Arc<Services>) -> Result<(Router, Guard)> { let server = &services.server; @@ -78,7 +78,7 @@ pub(crate) fn build(services: &Arc<Services>) -> Result<(Router, Guard)> { )) .layer(SetResponseHeaderLayer::if_not_present( header::CONTENT_SECURITY_POLICY, - HeaderValue::from_str(&CONDUWUIT_CSP.join("; "))?, + HeaderValue::from_str(&CONDUWUIT_CSP.join(";"))?, )) .layer(cors_layer(server)) .layer(body_limit_layer(server)) @@ -184,12 +184,20 @@ fn catch_panic(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_u } fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span { - let path = request - .extensions() - .get::<MatchedPath>() - .map_or_else(|| request.uri().path(), truncated_matched_path); - - tracing::info_span!("router:", %path) + let path = request.extensions().get::<MatchedPath>().map_or_else( + || { + request + .uri() + .path_and_query() + .expect("all requests have a path") + .as_str() + }, + truncated_matched_path, + ); + + let method = request.method(); + + tracing::info_span!("router:", %method, %path) } fn truncated_matched_path(path: &MatchedPath) -> &str { diff --git a/src/router/mod.rs b/src/router/mod.rs index 67ebc0e3f31ed96eeb39472c470bc64d7fc39264..1580f605182a56dc5e1a21d0b2f70679df895742 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -1,5 +1,3 @@ -#![recursion_limit = "160"] - mod layers; mod request; mod router; @@ -8,10 +6,11 @@ extern crate conduit_core as conduit; -use std::{future::Future, pin::Pin, sync::Arc}; +use std::{panic::AssertUnwindSafe, pin::Pin, sync::Arc}; -use conduit::{Result, Server}; +use conduit::{Error, Result, Server}; use conduit_service::Services; +use futures::{Future, FutureExt, TryFutureExt}; conduit::mod_ctor! {} conduit::mod_dtor! 
{} @@ -19,15 +18,27 @@ #[no_mangle] pub extern "Rust" fn start(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<Arc<Services>>> + Send>> { - Box::pin(run::start(server.clone())) + AssertUnwindSafe(run::start(server.clone())) + .catch_unwind() + .map_err(Error::from_panic) + .unwrap_or_else(Err) + .boxed() } #[no_mangle] pub extern "Rust" fn stop(services: Arc<Services>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { - Box::pin(run::stop(services)) + AssertUnwindSafe(run::stop(services)) + .catch_unwind() + .map_err(Error::from_panic) + .unwrap_or_else(Err) + .boxed() } #[no_mangle] pub extern "Rust" fn run(services: &Arc<Services>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { - Box::pin(run::run(services.clone())) + AssertUnwindSafe(run::run(services.clone())) + .catch_unwind() + .map_err(Error::from_panic) + .unwrap_or_else(Err) + .boxed() } diff --git a/src/router/serve/plain.rs b/src/router/serve/plain.rs index 08263353bf0adb36cc16a0abe24b3934f6980284..144bff85da0bab6a50196a8882af106ba2329faf 100644 --- a/src/router/serve/plain.rs +++ b/src/router/serve/plain.rs @@ -5,9 +5,8 @@ use axum::Router; use axum_server::{bind, Handle as ServerHandle}; -use conduit::{debug_info, Result, Server}; +use conduit::{debug_info, info, Result, Server}; use tokio::task::JoinSet; -use tracing::info; pub(super) async fn serve( server: &Arc<Server>, app: Router, handle: ServerHandle, addrs: Vec<SocketAddr>, diff --git a/src/router/serve/unix.rs b/src/router/serve/unix.rs index fb011f1883199a8578f1e2765d0fa41b1100be4f..5df41b61432d0f1ef86b10248959a7887c58cdae 100644 --- a/src/router/serve/unix.rs +++ b/src/router/serve/unix.rs @@ -10,7 +10,7 @@ extract::{connect_info::IntoMakeServiceWithConnectInfo, Request}, Router, }; -use conduit::{debug, debug_error, error::infallible, info, trace, warn, Err, Result, Server}; +use conduit::{debug, debug_error, info, result::UnwrapInfallible, trace, warn, Err, Result, Server}; use hyper::{body::Incoming, service::service_fn}; use hyper_util::{ rt::{TokioExecutor, TokioIo}, @@ -62,11 +62,7 @@ async fn accept( let socket = TokioIo::new(socket); trace!(?listener, ?socket, ?remote, "accepted"); - let called = app - .call(NULL_ADDR) - .await - .inspect_err(infallible) - .expect("infallible"); + let called = app.call(NULL_ADDR).await.unwrap_infallible(); let service = move |req: Request<Incoming>| called.clone().oneshot(req); let handler = service_fn(service); diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index cfed5a0e3ee30310b37dc24f0214d89c9c70ec4f..7578ef64f97b3307ad6507254ba137e869b40e37 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -40,13 +40,14 @@ release_max_log_level = [ ] [dependencies] +arrayvec.workspace = true async-trait.workspace = true base64.workspace = true bytes.workspace = true conduit-core.workspace = true conduit-database.workspace = true const-str.workspace = true -futures-util.workspace = true +futures.workspace = true hickory-resolver.workspace = true http.workspace = true image.workspace = true diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs deleted file mode 100644 index 53a0e953388a54681c5e7b198089ce11c0740fe8..0000000000000000000000000000000000000000 --- a/src/service/account_data/data.rs +++ /dev/null @@ -1,152 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use conduit::{Error, Result}; -use database::Map; -use ruma::{ - api::client::error::ErrorKind, - events::{AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, 
RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, -}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - roomuserdataid_accountdata: Arc<Map>, - roomusertype_roomuserdataid: Arc<Map>, - services: Services, -} - -struct Services { - globals: Dep<globals::Service>, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - roomuserdataid_accountdata: db["roomuserdataid_accountdata"].clone(), - roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(), - services: Services { - globals: args.depend::<globals::Service>("globals"), - }, - } - } - - /// Places one event in the account data of the user and removes the - /// previous entry. - pub(super) fn update( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: &RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - roomuserdataid.push(0xFF); - roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); - - let mut key = prefix; - key.extend_from_slice(event_type.to_string().as_bytes()); - - if data.get("type").is_none() || data.get("content").is_none() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Account data doesn't have all required fields.", - )); - } - - self.roomuserdataid_accountdata.insert( - &roomuserdataid, - &serde_json::to_vec(&data).expect("to_vec always works on json values"), - )?; - - let prev = self.roomusertype_roomuserdataid.get(&key)?; - - self.roomusertype_roomuserdataid - .insert(&key, &roomuserdataid)?; - - // Remove old entry - if let Some(prev) = prev { - self.roomuserdataid_accountdata.remove(&prev)?; - } - - Ok(()) - } - - /// Searches the account data for a specific kind. - pub(super) fn get( - &self, room_id: Option<&RoomId>, user_id: &UserId, kind: &RoomAccountDataEventType, - ) -> Result<Option<Box<serde_json::value::RawValue>>> { - let mut key = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(kind.to_string().as_bytes()); - - self.roomusertype_roomuserdataid - .get(&key)? - .and_then(|roomuserdataid| { - self.roomuserdataid_accountdata - .get(&roomuserdataid) - .transpose() - }) - .transpose()? - .map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize"))) - .transpose() - } - - /// Returns all changes to the account data that happened after `since`. 
- pub(super) fn changes_since( - &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result<Vec<AnyRawAccountDataEvent>> { - let mut userdata = HashMap::new(); - - let mut prefix = room_id - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xFF); - - // Skip the data that's exactly at since, because we sent that last time - let mut first_possible = prefix.clone(); - first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); - - for r in self - .roomuserdataid_accountdata - .iter_from(&first_possible, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(k, v)| { - Ok::<_, Error>(( - k, - match room_id { - None => serde_json::from_slice::<Raw<AnyGlobalAccountDataEvent>>(&v) - .map(AnyRawAccountDataEvent::Global) - .map_err(|_| Error::bad_database("Database contains invalid account data."))?, - Some(_) => serde_json::from_slice::<Raw<AnyRoomAccountDataEvent>>(&v) - .map(AnyRawAccountDataEvent::Room) - .map_err(|_| Error::bad_database("Database contains invalid account data."))?, - }, - )) - }) { - let (kind, data) = r?; - userdata.insert(kind, data); - } - - Ok(userdata.into_values().collect()) - } -} diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index eaa536417c9ae2f5b3410b70c801a4860ffcb2c8..ac3f5f83e569ce72d6173d3088430e89b22002ec 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -1,52 +1,157 @@ -mod data; +use std::{collections::HashMap, sync::Arc}; -use std::sync::Arc; - -use conduit::Result; -use data::Data; +use conduit::{ + implement, + utils::{stream::TryIgnore, ReadyExt}, + Err, Error, Result, +}; +use database::{Deserialized, Handle, Json, Map}; +use futures::{StreamExt, TryFutureExt}; use ruma::{ - events::{AnyRawAccountDataEvent, RoomAccountDataEventType}, + events::{ + AnyGlobalAccountDataEvent, AnyRawAccountDataEvent, AnyRoomAccountDataEvent, GlobalAccountDataEventType, + RoomAccountDataEventType, + }, + serde::Raw, RoomId, UserId, }; +use serde::Deserialize; + +use crate::{globals, Dep}; pub struct Service { + services: Services, db: Data, } +struct Data { + roomuserdataid_accountdata: Arc<Map>, + roomusertype_roomuserdataid: Arc<Map>, +} + +struct Services { + globals: Dep<globals::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(&args), + services: Services { + globals: args.depend::<globals::Service>("globals"), + }, + db: Data { + roomuserdataid_accountdata: args.db["roomuserdataid_accountdata"].clone(), + roomusertype_roomuserdataid: args.db["roomusertype_roomuserdataid"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Places one event in the account data of the user and removes the - /// previous entry. - #[allow(clippy::needless_pass_by_value)] - pub fn update( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - self.db.update(room_id, user_id, &event_type, data) +/// Places one event in the account data of the user and removes the +/// previous entry. 
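+///
+/// Values are stored as JSON keyed on `(room_id, user_id, count, event_type)`,
+/// where `count` is a fresh global counter; `(room_id, user_id, event_type)`
+/// is then pointed at that newest key so the stale entry can be dropped.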
+#[allow(clippy::needless_pass_by_value)]
+#[implement(Service)]
+pub async fn update(
+	&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value,
+) -> Result<()> {
+	if data.get("type").is_none() || data.get("content").is_none() {
+		return Err!(Request(InvalidParam("Account data doesn't have all required fields.")));
 	}
 
-	/// Searches the account data for a specific kind.
-	#[allow(clippy::needless_pass_by_value)]
-	pub fn get(
-		&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
-	) -> Result<Option<Box<serde_json::value::RawValue>>> {
-		self.db.get(room_id, user_id, &event_type)
-	}
+	let count = self.services.globals.next_count()?;
+	let roomuserdataid = (room_id, user_id, count, &event_type);
+	self.db
+		.roomuserdataid_accountdata
+		.put(roomuserdataid, Json(data));
+
+	let key = (room_id, user_id, &event_type);
+	let prev = self.db.roomusertype_roomuserdataid.qry(&key).await;
+	self.db.roomusertype_roomuserdataid.put(key, roomuserdataid);
 
-	/// Returns all changes to the account data that happened after `since`.
-	#[tracing::instrument(skip_all, name = "since", level = "debug")]
-	pub fn changes_since(
-		&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
-	) -> Result<Vec<AnyRawAccountDataEvent>> {
-		self.db.changes_since(room_id, user_id, since)
+	// Remove old entry
+	if let Ok(prev) = prev {
+		self.db.roomuserdataid_accountdata.remove(&prev);
 	}
+
+	Ok(())
+}
+
+/// Searches the global account data for a specific kind.
+#[implement(Service)]
+pub async fn get_global<T>(&self, user_id: &UserId, kind: GlobalAccountDataEventType) -> Result<T>
+where
+	T: for<'de> Deserialize<'de>,
+{
+	self.get_raw(None, user_id, &kind.to_string())
+		.await
+		.deserialized()
+}
+
+/// Searches the room account data for a specific kind.
+#[implement(Service)]
+pub async fn get_room<T>(&self, room_id: &RoomId, user_id: &UserId, kind: RoomAccountDataEventType) -> Result<T>
+where
+	T: for<'de> Deserialize<'de>,
+{
+	self.get_raw(Some(room_id), user_id, &kind.to_string())
+		.await
+		.deserialized()
+}
+
+#[implement(Service)]
+pub async fn get_raw(&self, room_id: Option<&RoomId>, user_id: &UserId, kind: &str) -> Result<Handle<'_>> {
+	let key = (room_id, user_id, kind.to_owned());
+	self.db
+		.roomusertype_roomuserdataid
+		.qry(&key)
+		.and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.get(&roomuserdataid))
+		.await
+}
+
+/// Returns all changes to the account data that happened after `since`.
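+///
+/// The entry stored at exactly `since` is excluded, since the previous call
+/// already returned it.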
+#[implement(Service)] +pub async fn changes_since( + &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, +) -> Result<Vec<AnyRawAccountDataEvent>> { + let mut userdata = HashMap::new(); + + let mut prefix = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xFF); + + // Skip the data that's exactly at since, because we sent that last time + let mut first_possible = prefix.clone(); + first_possible.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); + + self.db + .roomuserdataid_accountdata + .raw_stream_from(&first_possible) + .ignore_err() + .ready_take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(k, v)| { + let v = match room_id { + None => serde_json::from_slice::<Raw<AnyGlobalAccountDataEvent>>(v) + .map(AnyRawAccountDataEvent::Global) + .map_err(|_| Error::bad_database("Database contains invalid account data."))?, + Some(_) => serde_json::from_slice::<Raw<AnyRoomAccountDataEvent>>(v) + .map(AnyRawAccountDataEvent::Room) + .map_err(|_| Error::bad_database("Database contains invalid account data."))?, + }; + + Ok((k.to_owned(), v)) + }) + .ignore_err() + .ready_for_each(|(kind, data)| { + userdata.insert(kind, data); + }) + .await; + + Ok(userdata.into_values().collect()) } diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index 55bae36582608f19de61c11c0aa46cc82b6205b7..0f5016e152e5391236e9ece74cf0d262cd37d111 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -5,7 +5,7 @@ }; use conduit::{debug, defer, error, log, Server}; -use futures_util::future::{AbortHandle, Abortable}; +use futures::future::{AbortHandle, Abortable}; use ruma::events::room::message::RoomMessageEventContent; use rustyline_async::{Readline, ReadlineError, ReadlineEvent}; use termimad::MadSkin; diff --git a/src/service/admin/create.rs b/src/service/admin/create.rs index 4e2b831c5f450ebb676d84b38fb0a3bc018ead31..1631f1cbb920dc1807ff3c7d5a0e9e0eddfe7b7a 100644 --- a/src/service/admin/create.rs +++ b/src/service/admin/create.rs @@ -2,24 +2,20 @@ use conduit::{pdu::PduBuilder, Result}; use ruma::{ - events::{ - room::{ - canonical_alias::RoomCanonicalAliasEventContent, - create::RoomCreateEventContent, - guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - join_rules::{JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, - name::RoomNameEventContent, - power_levels::RoomPowerLevelsEventContent, - preview_url::RoomPreviewUrlsEventContent, - topic::RoomTopicEventContent, - }, - TimelineEventType, + events::room::{ + canonical_alias::RoomCanonicalAliasEventContent, + create::RoomCreateEventContent, + guest_access::{GuestAccess, RoomGuestAccessEventContent}, + history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + name::RoomNameEventContent, + power_levels::RoomPowerLevelsEventContent, + preview_url::RoomPreviewUrlsEventContent, + topic::RoomTopicEventContent, }, RoomId, RoomVersionId, }; -use serde_json::value::to_raw_value; use crate::Services; @@ -30,7 +26,11 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { let room_id = RoomId::new(services.globals.server_name()); - let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; + let _short_id 
= services + .rooms + .short + .get_or_create_shortroomid(&room_id) + .await; let state_lock = services.rooms.state.mutex.lock(&room_id).await; @@ -40,7 +40,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { let room_version = services.globals.default_room_version(); - let mut content = { + let create_content = { use RoomVersionId::*; match room_version { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => RoomCreateEventContent::new_v1(server_user.clone()), @@ -48,23 +48,20 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { } }; - content.federate = true; - content.predecessor = None; - content.room_version = room_version; - // 1. The room create event services .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + String::new(), + &RoomCreateEventContent { + federate: true, + predecessor: None, + room_version, + ..create_content + }, + ), server_user, &room_id, &state_lock, @@ -76,24 +73,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(server_user.to_string()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(server_user.to_string(), &RoomMemberEventContent::new(MembershipState::Join)), server_user, &room_id, &state_lock, @@ -107,18 +87,13 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { + PduBuilder::state( + String::new(), + &RoomPowerLevelsEventContent { users, ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), server_user, &room_id, &state_lock, @@ -130,15 +105,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomJoinRules, - content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &RoomJoinRulesEventContent::new(JoinRule::Invite)), server_user, &room_id, &state_lock, @@ -150,15 +117,10 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomHistoryVisibility, - content: to_raw_value(&RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state( + String::new(), + &RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared), + ), server_user, &room_id, 
&state_lock, @@ -170,15 +132,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomGuestAccess, - content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &RoomGuestAccessEventContent::new(GuestAccess::Forbidden)), server_user, &room_id, &state_lock, @@ -191,15 +145,7 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomName, - content: to_raw_value(&RoomNameEventContent::new(room_name)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + PduBuilder::state(String::new(), &RoomNameEventContent::new(room_name)), server_user, &room_id, &state_lock, @@ -210,17 +156,12 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomTopic, - content: to_raw_value(&RoomTopicEventContent { + PduBuilder::state( + String::new(), + &RoomTopicEventContent { topic: format!("Manage {}", services.globals.server_name()), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), server_user, &room_id, &state_lock, @@ -234,18 +175,13 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCanonicalAlias, - content: to_raw_value(&RoomCanonicalAliasEventContent { + PduBuilder::state( + String::new(), + &RoomCanonicalAliasEventContent { alias: Some(alias.clone()), alt_aliases: Vec::new(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), server_user, &room_id, &state_lock, @@ -262,17 +198,12 @@ pub async fn create_admin_room(services: &Services) -> Result<()> { .rooms .timeline .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPreviewUrls, - content: to_raw_value(&RoomPreviewUrlsEventContent { + PduBuilder::state( + String::new(), + &RoomPreviewUrlsEventContent { disabled: true, - }) - .expect("event is valid we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, - }, + }, + ), server_user, &room_id, &state_lock, diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index b4589ebc8a8911bf4752fb677f9128491d32db49..405da982e9af40451512c81484ef9fdf64707bee 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -9,145 +9,98 @@ power_levels::RoomPowerLevelsEventContent, }, tag::{TagEvent, TagEventContent, TagInfo}, - RoomAccountDataEventType, TimelineEventType, + RoomAccountDataEventType, }, RoomId, UserId, }; -use serde_json::value::to_raw_value; use crate::pdu::PduBuilder; -impl super::Service { - /// Invite the user to the conduit admin room. - /// - /// In conduit, this is equivalent to granting admin privileges. - pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { - let Some(room_id) = self.get_admin_room()? 
else { - return Ok(()); - }; - - let state_lock = self.services.state.mutex.lock(&room_id).await; - - // Use the server user to grant the new admin's power level - let server_user = &self.services.globals.server_user; - - // Invite and join the real user - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - server_user, - &room_id, - &state_lock, - ) - .await?; - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - timestamp: None, - }, - user_id, - &room_id, - &state_lock, - ) - .await?; - - // Set power level - let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]); - - self.services - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(String::new()), - redacts: None, - timestamp: None, +/// Invite the user to the conduit admin room. +/// +/// In conduit, this is equivalent to granting admin privileges. 
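Throughout these hunks, hand-rolled PduBuilder struct literals collapse into two convenience constructors, PduBuilder::state and PduBuilder::timeline. Their definitions are not shown in this patch; the following is a minimal sketch inferred from the removed literals (field set and defaults) and the call shapes above, not the authoritative implementation:

```rust
// Sketch only: inferred from the call sites in this diff, not taken from the
// patch itself. Assumes ruma's EventContent supplies the event type and that
// TimelineEventType implements From<StateEventType> / From<MessageLikeEventType>.
use ruma::events::{EventContent, MessageLikeEventType, StateEventType};
use serde_json::value::to_raw_value;

impl PduBuilder {
	/// State event: the event type comes from the content type, the caller
	/// supplies the state key, every other field keeps its default.
	pub fn state<T>(state_key: String, content: &T) -> Self
	where
		T: EventContent<EventType = StateEventType>,
	{
		Self {
			event_type: content.event_type().into(),
			content: to_raw_value(content).expect("event content serializes"),
			state_key: Some(state_key),
			unsigned: None,
			redacts: None,
			timestamp: None,
		}
	}

	/// Message-like timeline event: same idea, but with no state key.
	pub fn timeline<T>(content: &T) -> Self
	where
		T: EventContent<EventType = MessageLikeEventType>,
	{
		Self {
			event_type: content.event_type().into(),
			content: to_raw_value(content).expect("event content serializes"),
			state_key: None,
			unsigned: None,
			redacts: None,
			timestamp: None,
		}
	}
}
```

Each call site then shrinks from a seven-field literal to a single constructor call, which accounts for most of the deletions in these hunks.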
+#[implement(super::Service)] +pub async fn make_user_admin(&self, user_id: &UserId) -> Result<()> { + let Ok(room_id) = self.get_admin_room().await else { + return Ok(()); + }; + + let state_lock = self.services.state.mutex.lock(&room_id).await; + + // Use the server user to grant the new admin's power level + let server_user = &self.services.globals.server_user; + + // Invite and join the real user + self.services + .timeline + .build_and_append_pdu( + PduBuilder::state(user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Invite)), + server_user, + &room_id, + &state_lock, + ) + .await?; + self.services + .timeline + .build_and_append_pdu( + PduBuilder::state(user_id.to_string(), &RoomMemberEventContent::new(MembershipState::Join)), + user_id, + &room_id, + &state_lock, + ) + .await?; + + // Set power level + let users = BTreeMap::from_iter([(server_user.clone(), 100.into()), (user_id.to_owned(), 100.into())]); + + self.services + .timeline + .build_and_append_pdu( + PduBuilder::state( + String::new(), + &RoomPowerLevelsEventContent { + users, + ..Default::default() }, - server_user, - &room_id, - &state_lock, - ) - .await?; - - // Set room tag - let room_tag = &self.services.server.config.admin_room_tag; - if !room_tag.is_empty() { - if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag) { - error!(?room_id, ?user_id, ?room_tag, ?e, "Failed to set tag for admin grant"); - } - } + ), + server_user, + &room_id, + &state_lock, + ) + .await?; - // Send welcome message - self.services.timeline.build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&RoomMessageEventContent::text_markdown( - String::from("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `!admin --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`"), - )) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - timestamp: None, - }, - server_user, - &room_id, - &state_lock, - ).await?; - - Ok(()) + // Set room tag + let room_tag = &self.services.server.config.admin_room_tag; + if !room_tag.is_empty() { + if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag).await { + error!(?room_id, ?user_id, ?room_tag, ?e, "Failed to set tag for admin grant"); + } } + + let welcome_message = String::from("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. 
This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `!admin --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`"); + + // Send welcome message + self.services + .timeline + .build_and_append_pdu( + PduBuilder::timeline(&RoomMessageEventContent::text_markdown(welcome_message)), + server_user, + &room_id, + &state_lock, + ) + .await?; + + Ok(()) } #[implement(super::Service)] -fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<()> { +async fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result<()> { let mut event = self .services .account_data - .get(Some(room_id), user_id, RoomAccountDataEventType::Tag)? - .map(|event| serde_json::from_str(event.get())) - .and_then(Result::ok) - .unwrap_or_else(|| TagEvent { + .get_room(room_id, user_id, RoomAccountDataEventType::Tag) + .await + .unwrap_or_else(|_| TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, @@ -158,12 +111,15 @@ fn set_room_tag(&self, room_id: &RoomId, user_id: &UserId, tag: &str) -> Result< .tags .insert(tag.to_owned().into(), TagInfo::new()); - self.services.account_data.update( - Some(room_id), - user_id, - RoomAccountDataEventType::Tag, - &serde_json::to_value(event)?, - )?; + self.services + .account_data + .update( + Some(room_id), + user_id, + RoomAccountDataEventType::Tag, + &serde_json::to_value(event)?, + ) + .await?; Ok(()) } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 3274249e6060f7e09c2d2628768a42059a4114e7..2860bd1bb480468491c0a8fbedf464eaee2960b0 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -12,15 +12,12 @@ use async_trait::async_trait; use conduit::{debug, err, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server}; pub use create::create_admin_room; +use futures::{FutureExt, TryFutureExt}; use loole::{Receiver, Sender}; use ruma::{ - events::{ - room::message::{Relation, RoomMessageEventContent}, - TimelineEventType, - }, + events::room::message::{Relation, RoomMessageEventContent}, OwnedEventId, OwnedRoomId, RoomId, UserId, }; -use serde_json::value::to_raw_value; use tokio::sync::{Mutex, RwLock}; use crate::{account_data, globals, rooms, rooms::state::RoomMutexGuard, Dep}; @@ -142,17 +139,18 @@ impl Service { /// admin room as the admin user. pub async fn send_text(&self, body: &str) { self.send_message(RoomMessageEventContent::text_markdown(body)) - .await; + .await + .ok(); } /// Sends a message to the admin room as the admin user (see send_text() for /// convenience). 
- pub async fn send_message(&self, message_content: RoomMessageEventContent) { - if let Ok(Some(room_id)) = self.get_admin_room() { - let user_id = &self.services.globals.server_user; - self.respond_to_room(message_content, &room_id, user_id) - .await; - } + pub async fn send_message(&self, message_content: RoomMessageEventContent) -> Result<()> { + let user_id = &self.services.globals.server_user; + let room_id = self.get_admin_room().await?; + self.respond_to_room(message_content, &room_id, user_id) + .boxed() + .await } /// Posts a command to the command processor queue and returns. Processing @@ -193,8 +191,11 @@ async fn handle_signal(&self, #[allow(unused_variables)] sig: &'static str) { async fn handle_command(&self, command: CommandInput) { match self.process_command(command).await { - Ok(Some(output)) | Err(output) => self.handle_response(output).await, Ok(None) => debug!("Command successful with no response"), + Ok(Some(output)) | Err(output) => self + .handle_response(output) + .await + .unwrap_or_else(default_log), } } @@ -218,113 +219,96 @@ async fn process_command(&self, command: CommandInput) -> ProcessorResult { } /// Checks whether a given user is an admin of this server - pub async fn user_is_admin(&self, user_id: &UserId) -> Result<bool> { - if let Ok(Some(admin_room)) = self.get_admin_room() { - self.services.state_cache.is_joined(user_id, &admin_room) - } else { - Ok(false) - } + pub async fn user_is_admin(&self, user_id: &UserId) -> bool { + let Ok(admin_room) = self.get_admin_room().await else { + return false; + }; + + self.services + .state_cache + .is_joined(user_id, &admin_room) + .await } /// Gets the room ID of the admin room /// /// Errors are propagated from the database, and will have None if there is /// no admin room - pub fn get_admin_room(&self) -> Result<Option<OwnedRoomId>> { - if let Some(room_id) = self + pub async fn get_admin_room(&self) -> Result<OwnedRoomId> { + let room_id = self .services .alias - .resolve_local_alias(&self.services.globals.admin_alias)? - { - if self - .services - .state_cache - .is_joined(&self.services.globals.server_user, &room_id)? 
- { - return Ok(Some(room_id)); - } - } + .resolve_local_alias(&self.services.globals.admin_alias) + .await?; - Ok(None) + self.services + .state_cache + .is_joined(&self.services.globals.server_user, &room_id) + .await + .then_some(room_id) + .ok_or_else(|| err!(Request(NotFound("Admin user not joined to admin room")))) } - async fn handle_response(&self, content: RoomMessageEventContent) { + async fn handle_response(&self, content: RoomMessageEventContent) -> Result<()> { let Some(Relation::Reply { in_reply_to, }) = content.relates_to.as_ref() else { - return; + return Ok(()); }; - let Ok(Some(pdu)) = self.services.timeline.get_pdu(&in_reply_to.event_id) else { + let Ok(pdu) = self.services.timeline.get_pdu(&in_reply_to.event_id).await else { error!( event_id = ?in_reply_to.event_id, "Missing admin command in_reply_to event" ); - return; + return Ok(()); }; - let response_sender = if self.is_admin_room(&pdu.room_id) { + let response_sender = if self.is_admin_room(&pdu.room_id).await { &self.services.globals.server_user } else { &pdu.sender }; self.respond_to_room(content, &pdu.room_id, response_sender) - .await; + .boxed() + .await } - async fn respond_to_room(&self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId) { - assert!( - self.user_is_admin(user_id) - .await - .expect("checked user is admin"), - "sender is not admin" - ); + async fn respond_to_room( + &self, content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId, + ) -> Result<()> { + assert!(self.user_is_admin(user_id).await, "sender is not admin"); let state_lock = self.services.state.mutex.lock(room_id).await; - let response_pdu = PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - timestamp: None, - }; if let Err(e) = self .services .timeline - .build_and_append_pdu(response_pdu, user_id, room_id, &state_lock) + .build_and_append_pdu(PduBuilder::timeline(&content), user_id, room_id, &state_lock) .await { self.handle_response_error(e, room_id, user_id, &state_lock) .await .unwrap_or_else(default_log); } + + Ok(()) } async fn handle_response_error( &self, e: Error, room_id: &RoomId, user_id: &UserId, state_lock: &RoomMutexGuard, ) -> Result<()> { error!("Failed to build and append admin room response PDU: \"{e}\""); - let error_room_message = RoomMessageEventContent::text_plain(format!( + let content = RoomMessageEventContent::text_plain(format!( "Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command may have finished \ successfully, but we could not return the output." 
)); - let response_pdu = PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&error_room_message).expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - timestamp: None, - }; - self.services .timeline - .build_and_append_pdu(response_pdu, user_id, room_id, state_lock) + .build_and_append_pdu(PduBuilder::timeline(&content), user_id, room_id, state_lock) .await?; Ok(()) @@ -355,12 +339,12 @@ pub async fn is_admin_command(&self, pdu: &PduEvent, body: &str) -> bool { } // Prevent unescaped !admin from being used outside of the admin room - if is_public_prefix && !self.is_admin_room(&pdu.room_id) { + if is_public_prefix && !self.is_admin_room(&pdu.room_id).await { return false; } // Only senders who are admin can proceed - if !self.user_is_admin(&pdu.sender).await.unwrap_or(false) { + if !self.user_is_admin(&pdu.sender).await { return false; } @@ -368,7 +352,7 @@ pub async fn is_admin_command(&self, pdu: &PduEvent, body: &str) -> bool { // the administrator can execute commands as conduit let emergency_password_set = self.services.globals.emergency_password().is_some(); let from_server = pdu.sender == *server_user && !emergency_password_set; - if from_server && self.is_admin_room(&pdu.room_id) { + if from_server && self.is_admin_room(&pdu.room_id).await { return false; } @@ -377,19 +361,18 @@ pub async fn is_admin_command(&self, pdu: &PduEvent, body: &str) -> bool { } #[must_use] - pub fn is_admin_room(&self, room_id: &RoomId) -> bool { - if let Ok(Some(admin_room_id)) = self.get_admin_room() { - admin_room_id == room_id - } else { - false - } + pub async fn is_admin_room(&self, room_id_: &RoomId) -> bool { + self.get_admin_room() + .map_ok(|room_id| room_id == room_id_) + .await + .unwrap_or(false) } /// Sets the self-reference to crate::Services which will provide context to /// the admin commands. 
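The set_services() hunk just below changes the parameter from &Option<Arc<...>> to Option<&Arc<...>>. The usual motivation, reduced to a dependency-free sketch (String stands in for Arc<crate::Services>; the function names are invented for illustration): Option<&T> accepts both owned and borrowed callers via as_ref(), and the callee no longer needs its own as_ref() call.

```rust
// Option<&T> is the more flexible parameter type than &Option<T>.
fn takes_ref_option(services: &Option<String>) -> Option<usize> {
	services.as_ref().map(String::len)
}

fn takes_option_ref(services: Option<&String>) -> Option<usize> {
	services.map(String::len) // no as_ref() needed inside the callee
}

fn main() {
	let owned = Some(String::from("conduwuit"));
	assert_eq!(takes_ref_option(&owned), Some(9));
	assert_eq!(takes_option_ref(owned.as_ref()), Some(9));
	assert_eq!(takes_option_ref(None), None);
}
```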
- pub(super) fn set_services(&self, services: &Option<Arc<crate::Services>>) { + pub(super) fn set_services(&self, services: Option<&Arc<crate::Services>>) { let receiver = &mut *self.services.services.write().expect("locked for writing"); - let weak = services.as_ref().map(Arc::downgrade); + let weak = services.map(Arc::downgrade); *receiver = weak; } } diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs deleted file mode 100644 index 40e641a1eaf9118125c298404cdb2090c51deaa1..0000000000000000000000000000000000000000 --- a/src/service/appservice/data.rs +++ /dev/null @@ -1,54 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::api::appservice::Registration; - -pub struct Data { - id_appserviceregistrations: Arc<Map>, -} - -impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { - Self { - id_appserviceregistrations: db["id_appserviceregistrations"].clone(), - } - } - - /// Registers an appservice and returns the ID to the caller - pub(super) fn register_appservice(&self, yaml: &Registration) -> Result<String> { - let id = yaml.id.as_str(); - self.id_appserviceregistrations - .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; - - Ok(id.to_owned()) - } - - /// Remove an appservice registration - /// - /// # Arguments - /// - /// * `service_name` - the name you send to register the service previously - pub(super) fn unregister_appservice(&self, service_name: &str) -> Result<()> { - self.id_appserviceregistrations - .remove(service_name.as_bytes())?; - Ok(()) - } - - pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> { - self.id_appserviceregistrations - .get(id.as_bytes())? - .map(|bytes| { - serde_yaml::from_slice(&bytes) - .map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")) - }) - .transpose() - } - - pub(super) fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> { - Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| { - utils::string_from_bytes(&id) - .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) - }))) - } -} diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index c0752d565f7f5ddd8b750f10cc66ba5830d43e6b..1617e6e6e8bf5554422508e2dcf03c0d48e35fce 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,138 +1,50 @@ -mod data; +mod namespace_regex; +mod registration_info; use std::{collections::BTreeMap, sync::Arc}; -use conduit::{err, Result}; -use data::Data; -use futures_util::Future; -use regex::RegexSet; -use ruma::{ - api::appservice::{Namespace, Registration}, - RoomAliasId, RoomId, UserId, -}; +use async_trait::async_trait; +use conduit::{err, utils::stream::TryIgnore, Result}; +use database::Map; +use futures::{Future, StreamExt, TryStreamExt}; +use ruma::{api::appservice::Registration, RoomAliasId, RoomId, UserId}; use tokio::sync::RwLock; +pub use self::{namespace_regex::NamespaceRegex, registration_info::RegistrationInfo}; use crate::{sending, Dep}; -/// Compiled regular expressions for a namespace -#[derive(Clone, Debug)] -pub struct NamespaceRegex { - pub exclusive: Option<RegexSet>, - pub non_exclusive: Option<RegexSet>, -} - -impl NamespaceRegex { - /// Checks if this namespace has rights to a namespace - #[inline] - #[must_use] - pub fn is_match(&self, heystack: &str) -> bool { - if self.is_exclusive_match(heystack) { - return true; - } - - if let 
Some(non_exclusive) = &self.non_exclusive { - if non_exclusive.is_match(heystack) { - return true; - } - } - false - } - - /// Checks if this namespace has exlusive rights to a namespace - #[inline] - #[must_use] - pub fn is_exclusive_match(&self, heystack: &str) -> bool { - if let Some(exclusive) = &self.exclusive { - if exclusive.is_match(heystack) { - return true; - } - } - false - } -} - -impl RegistrationInfo { - #[must_use] - pub fn is_user_match(&self, user_id: &UserId) -> bool { - self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() - } - - #[inline] - #[must_use] - pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool { - self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() - } -} - -impl TryFrom<Vec<Namespace>> for NamespaceRegex { - type Error = regex::Error; - - fn try_from(value: Vec<Namespace>) -> Result<Self, regex::Error> { - let mut exclusive = Vec::with_capacity(value.len()); - let mut non_exclusive = Vec::with_capacity(value.len()); - - for namespace in value { - if namespace.exclusive { - exclusive.push(namespace.regex); - } else { - non_exclusive.push(namespace.regex); - } - } - - Ok(Self { - exclusive: if exclusive.is_empty() { - None - } else { - Some(RegexSet::new(exclusive)?) - }, - non_exclusive: if non_exclusive.is_empty() { - None - } else { - Some(RegexSet::new(non_exclusive)?) - }, - }) - } -} - -/// Appservice registration combined with its compiled regular expressions. -#[derive(Clone, Debug)] -pub struct RegistrationInfo { - pub registration: Registration, - pub users: NamespaceRegex, - pub aliases: NamespaceRegex, - pub rooms: NamespaceRegex, -} - -impl TryFrom<Registration> for RegistrationInfo { - type Error = regex::Error; - - fn try_from(value: Registration) -> Result<Self, regex::Error> { - Ok(Self { - users: value.namespaces.users.clone().try_into()?, - aliases: value.namespaces.aliases.clone().try_into()?, - rooms: value.namespaces.rooms.clone().try_into()?, - registration: value, - }) - } -} - pub struct Service { - pub db: Data, - services: Services, registration_info: RwLock<BTreeMap<String, RegistrationInfo>>, + services: Services, + db: Data, } struct Services { sending: Dep<sending::Service>, } +struct Data { + id_appserviceregistrations: Arc<Map>, +} + +#[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { - let mut registration_info = BTreeMap::new(); - let db = Data::new(args.db); + Ok(Arc::new(Self { + registration_info: RwLock::new(BTreeMap::new()), + services: Services { + sending: args.depend::<sending::Service>("sending"), + }, + db: Data { + id_appserviceregistrations: args.db["id_appserviceregistrations"].clone(), + }, + })) + } + + async fn worker(self: Arc<Self>) -> Result<()> { // Inserting registrations into cache - for appservice in iter_ids(&db)? { - registration_info.insert( + for appservice in self.iter_db_ids().await? 
{ + self.registration_info.write().await.insert( appservice.0, appservice .1 @@ -141,22 +53,13 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { ); } - Ok(Arc::new(Self { - db, - services: Services { - sending: args.depend::<sending::Service>("sending"), - }, - registration_info: RwLock::new(registration_info), - })) + Ok(()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - #[inline] - pub fn all(&self) -> Result<Vec<(String, Registration)>> { iter_ids(&self.db) } - /// Registers an appservice and returns the ID to the caller pub async fn register_appservice(&self, yaml: Registration) -> Result<String> { //TODO: Check for collisions between exclusive appservice namespaces @@ -165,7 +68,11 @@ pub async fn register_appservice(&self, yaml: Registration) -> Result<String> { .await .insert(yaml.id.clone(), yaml.clone().try_into()?); - self.db.register_appservice(&yaml) + let id = yaml.id.as_str(); + let yaml = serde_yaml::to_string(&yaml)?; + self.db.id_appserviceregistrations.insert(id, yaml); + + Ok(id.to_owned()) } /// Remove an appservice registration @@ -182,13 +89,14 @@ pub async fn unregister_appservice(&self, service_name: &str) -> Result<()> { .ok_or(err!("Appservice not found"))?; // remove the appservice from the database - self.db.unregister_appservice(service_name)?; + self.db.id_appserviceregistrations.remove(service_name); // deletes all active requests for the appservice if there are any so we stop // sending to the URL self.services .sending - .cleanup_events(service_name.to_owned())?; + .cleanup_events(service_name.to_owned()) + .await; Ok(()) } @@ -249,17 +157,29 @@ pub async fn is_exclusive_room_id(&self, room_id: &RoomId) -> bool { pub fn read(&self) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> { self.registration_info.read() } -} -fn iter_ids(db: &Data) -> Result<Vec<(String, Registration)>> { - db.iter_ids()? - .filter_map(Result::ok) - .map(move |id| { - Ok(( - id.clone(), - db.get_registration(&id)? 
- .expect("iter_ids only returns appservices that exist"), - )) - }) - .collect() + #[inline] + pub async fn all(&self) -> Result<Vec<(String, Registration)>> { self.iter_db_ids().await } + + pub async fn get_db_registration(&self, id: &str) -> Result<Registration> { + self.db + .id_appserviceregistrations + .get(id) + .await + .and_then(|ref bytes| serde_yaml::from_slice(bytes).map_err(Into::into)) + .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) + } + + async fn iter_db_ids(&self) -> Result<Vec<(String, Registration)>> { + self.db + .id_appserviceregistrations + .keys() + .ignore_err() + .then(|id: String| async move { + let reg = self.get_db_registration(&id).await?; + Ok((id, reg)) + }) + .try_collect() + .await + } } diff --git a/src/service/appservice/namespace_regex.rs b/src/service/appservice/namespace_regex.rs new file mode 100644 index 0000000000000000000000000000000000000000..3529fc0ef7daa69e77ec2f28819d74afc89ec196 --- /dev/null +++ b/src/service/appservice/namespace_regex.rs @@ -0,0 +1,70 @@ +use conduit::Result; +use regex::RegexSet; +use ruma::api::appservice::Namespace; + +/// Compiled regular expressions for a namespace +#[derive(Clone, Debug)] +pub struct NamespaceRegex { + pub exclusive: Option<RegexSet>, + pub non_exclusive: Option<RegexSet>, +} + +impl NamespaceRegex { + /// Checks if this namespace has rights to a namespace + #[inline] + #[must_use] + pub fn is_match(&self, heystack: &str) -> bool { + if self.is_exclusive_match(heystack) { + return true; + } + + if let Some(non_exclusive) = &self.non_exclusive { + if non_exclusive.is_match(heystack) { + return true; + } + } + false + } + + /// Checks if this namespace has exlusive rights to a namespace + #[inline] + #[must_use] + pub fn is_exclusive_match(&self, heystack: &str) -> bool { + if let Some(exclusive) = &self.exclusive { + if exclusive.is_match(heystack) { + return true; + } + } + false + } +} + +impl TryFrom<Vec<Namespace>> for NamespaceRegex { + type Error = regex::Error; + + fn try_from(value: Vec<Namespace>) -> Result<Self, regex::Error> { + let mut exclusive = Vec::with_capacity(value.len()); + let mut non_exclusive = Vec::with_capacity(value.len()); + + for namespace in value { + if namespace.exclusive { + exclusive.push(namespace.regex); + } else { + non_exclusive.push(namespace.regex); + } + } + + Ok(Self { + exclusive: if exclusive.is_empty() { + None + } else { + Some(RegexSet::new(exclusive)?) + }, + non_exclusive: if non_exclusive.is_empty() { + None + } else { + Some(RegexSet::new(non_exclusive)?) + }, + }) + } +} diff --git a/src/service/appservice/registration_info.rs b/src/service/appservice/registration_info.rs new file mode 100644 index 0000000000000000000000000000000000000000..2c8595b1b18eefc287c190d8d88be72b0c33a1eb --- /dev/null +++ b/src/service/appservice/registration_info.rs @@ -0,0 +1,39 @@ +use conduit::Result; +use ruma::{api::appservice::Registration, UserId}; + +use super::NamespaceRegex; + +/// Appservice registration combined with its compiled regular expressions. 
+#[derive(Clone, Debug)] +pub struct RegistrationInfo { + pub registration: Registration, + pub users: NamespaceRegex, + pub aliases: NamespaceRegex, + pub rooms: NamespaceRegex, +} + +impl RegistrationInfo { + #[must_use] + pub fn is_user_match(&self, user_id: &UserId) -> bool { + self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() + } + + #[inline] + #[must_use] + pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool { + self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() + } +} + +impl TryFrom<Registration> for RegistrationInfo { + type Error = regex::Error; + + fn try_from(value: Registration) -> Result<Self, regex::Error> { + Ok(Self { + users: value.namespaces.users.clone().try_into()?, + aliases: value.namespaces.aliases.clone().try_into()?, + rooms: value.namespaces.rooms.clone().try_into()?, + registration: value, + }) + } +} diff --git a/src/service/client/mod.rs b/src/service/client/mod.rs index b21f9dab5cf53ac93797e78b887c34b1be2c4bb1..f9a89e99ddd2c5acf40419ebe143adb9f9cf7d38 100644 --- a/src/service/client/mod.rs +++ b/src/service/client/mod.rs @@ -11,6 +11,7 @@ pub struct Service { pub extern_media: reqwest::Client, pub well_known: reqwest::Client, pub federation: reqwest::Client, + pub synapse: reqwest::Client, pub sender: reqwest::Client, pub appservice: reqwest::Client, pub pusher: reqwest::Client, @@ -48,12 +49,18 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { federation: base(config)? .dns_resolver(resolver.resolver.hooked.clone()) .read_timeout(Duration::from_secs(config.federation_timeout)) - .timeout(Duration::from_secs(config.federation_timeout)) .pool_max_idle_per_host(config.federation_idle_per_host.into()) .pool_idle_timeout(Duration::from_secs(config.federation_idle_timeout)) .redirect(redirect::Policy::limited(3)) .build()?, + synapse: base(config)? + .dns_resolver(resolver.resolver.hooked.clone()) + .read_timeout(Duration::from_secs(305)) + .pool_max_idle_per_host(0) + .redirect(redirect::Policy::limited(3)) + .build()?, + sender: base(config)? 
.dns_resolver(resolver.resolver.hooked.clone()) .read_timeout(Duration::from_secs(config.sender_timeout)) diff --git a/src/service/emergency/mod.rs b/src/service/emergency/mod.rs index 1bb0843d42258218017452dce5d9ceb91bc509c9..c99a0891e445278ee83980cb1cfad11d37b71b43 100644 --- a/src/service/emergency/mod.rs +++ b/src/service/emergency/mod.rs @@ -32,7 +32,12 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { } async fn worker(self: Arc<Self>) -> Result<()> { + if self.services.globals.is_read_only() { + return Ok(()); + } + self.set_emergency_access() + .await .inspect_err(|e| error!("Could not set the configured emergency password for the conduit user: {e}"))?; Ok(()) @@ -44,7 +49,7 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { /// Sets the emergency password and push rules for the @conduit account in /// case emergency password is set - fn set_emergency_access(&self) -> Result<bool> { + async fn set_emergency_access(&self) -> Result<bool> { let conduit_user = &self.services.globals.server_user; self.services @@ -56,17 +61,20 @@ fn set_emergency_access(&self) -> Result<bool> { None => (Ruleset::new(), false), }; - self.services.account_data.update( - None, - conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { - global: ruleset, - }, - }) - .expect("to json value always works"), - )?; + self.services + .account_data + .update( + None, + conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { + global: ruleset, + }, + }) + .expect("to json value always works"), + ) + .await?; if pwd_set { warn!( @@ -75,7 +83,7 @@ fn set_emergency_access(&self) -> Result<bool> { ); } else { // logs out any users still in the server service account and removes sessions - self.services.users.deactivate_account(conduit_user)?; + self.services.users.deactivate_account(conduit_user).await?; } Ok(pwd_set) diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 5b5d9f09df63b8edbb6fb47a49d7ec14e7d7575b..f715e944a6e42c6fd14ff66db116ffe611b11aeb 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,43 +1,12 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; +use std::sync::{Arc, RwLock}; -use conduit::{trace, utils, Error, Result, Server}; -use database::{Database, Map}; -use futures_util::{stream::FuturesUnordered, StreamExt}; -use ruma::{ - api::federation::discovery::{ServerSigningKeys, VerifyKey}, - signatures::Ed25519KeyPair, - DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, -}; - -use crate::{rooms, Dep}; +use conduit::{utils, Result}; +use database::{Database, Deserialized, Map}; pub struct Data { global: Arc<Map>, - todeviceid_events: Arc<Map>, - userroomid_joined: Arc<Map>, - userroomid_invitestate: Arc<Map>, - userroomid_leftstate: Arc<Map>, - userroomid_notificationcount: Arc<Map>, - userroomid_highlightcount: Arc<Map>, - pduid_pdu: Arc<Map>, - keychangeid_userid: Arc<Map>, - roomusertype_roomuserdataid: Arc<Map>, - server_signingkeys: Arc<Map>, - readreceiptid_readreceipt: Arc<Map>, - userid_lastonetimekeyupdate: Arc<Map>, counter: RwLock<u64>, pub(super) db: Arc<Database>, - services: Services, -} - -struct Services { - server: Arc<Server>, - short: Dep<rooms::short::Service>, - state_cache: Dep<rooms::state_cache::Service>, - typing: 
Dep<rooms::typing::Service>, } const COUNTER: &[u8] = b"c"; @@ -47,26 +16,8 @@ pub(super) fn new(args: &crate::Args<'_>) -> Self { let db = &args.db; Self { global: db["global"].clone(), - todeviceid_events: db["todeviceid_events"].clone(), - userroomid_joined: db["userroomid_joined"].clone(), - userroomid_invitestate: db["userroomid_invitestate"].clone(), - userroomid_leftstate: db["userroomid_leftstate"].clone(), - userroomid_notificationcount: db["userroomid_notificationcount"].clone(), - userroomid_highlightcount: db["userroomid_highlightcount"].clone(), - pduid_pdu: db["pduid_pdu"].clone(), - keychangeid_userid: db["keychangeid_userid"].clone(), - roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(), - server_signingkeys: db["server_signingkeys"].clone(), - readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(), - userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")), db: args.db.clone(), - services: Services { - server: args.server.clone(), - short: args.depend::<rooms::short::Service>("rooms::short"), - state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), - typing: args.depend::<rooms::typing::Service>("rooms::typing"), - }, } } @@ -83,7 +34,7 @@ pub fn next_count(&self) -> Result<u64> { .checked_add(1) .expect("counter must not overflow u64"); - self.global.insert(COUNTER, &counter.to_be_bytes())?; + self.global.insert(COUNTER, counter.to_be_bytes()); Ok(*counter) } @@ -102,232 +53,27 @@ pub fn current_count(&self) -> u64 { fn stored_count(global: &Arc<Map>) -> Result<u64> { global - .get(COUNTER)? + .get_blocking(COUNTER) .as_deref() .map_or(Ok(0_u64), utils::u64_from_bytes) } - #[tracing::instrument(skip(self), level = "debug")] - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let userid_bytes = user_id.as_bytes().to_vec(); - let mut userid_prefix = userid_bytes.clone(); - userid_prefix.push(0xFF); - - let mut userdeviceid_prefix = userid_prefix.clone(); - userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); - userdeviceid_prefix.push(0xFF); - - let mut futures = FuturesUnordered::new(); - - // Return when *any* user changed their key - // TODO: only send for user they share a room with - futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix)); - - futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); - futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); - futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); - futures.push( - self.userroomid_notificationcount - .watch_prefix(&userid_prefix), - ); - futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); - - // Events for rooms we are in - for room_id in self - .services - .state_cache - .rooms_joined(user_id) - .filter_map(Result::ok) - { - let short_roomid = self - .services - .short - .get_shortroomid(&room_id) - .ok() - .flatten() - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let roomid_bytes = room_id.as_bytes().to_vec(); - let mut roomid_prefix = roomid_bytes.clone(); - roomid_prefix.push(0xFF); - - // PDUs - futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); - - // EDUs - futures.push(Box::pin(async move { - let _result = self.services.typing.wait_for_update(&room_id).await; - })); - - futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); - - // Key changes - 
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); - - // Room account data - let mut roomuser_prefix = roomid_prefix.clone(); - roomuser_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&roomuser_prefix), - ); - } - - let mut globaluserdata_prefix = vec![0xFF]; - globaluserdata_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&globaluserdata_prefix), - ); - - // More key changes (used when user is not joined to any rooms) - futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); - - // One time keys - futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); - - futures.push(Box::pin(async move { - while self.services.server.running() { - let _result = self.services.server.signal.subscribe().recv().await; - } - })); - - if !self.services.server.running() { - return Ok(()); - } - - // Wait until one of them finds something - trace!(futures = futures.len(), "watch started"); - futures.next().await; - trace!(futures = futures.len(), "watch finished"); - - Ok(()) - } - - pub fn load_keypair(&self) -> Result<Ed25519KeyPair> { - let keypair_bytes = self.global.get(b"keypair")?.map_or_else( - || { - let keypair = utils::generate_keypair(); - self.global.insert(b"keypair", &keypair)?; - Ok::<_, Error>(keypair) - }, - |val| Ok(val.to_vec()), - )?; - - let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); - - utils::string_from_bytes( - // 1. version - parts - .next() - .expect("splitn always returns at least one element"), - ) - .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) - .and_then(|version| { - // 2. key - parts - .next() - .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) - .map(|key| (version, key)) - }) - .and_then(|(version, key)| { - Ed25519KeyPair::from_der(key, version) - .map_err(|_| Error::bad_database("Private or public keys are invalid.")) - }) - } - - #[inline] - pub fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } - - /// TODO: the key valid until timestamp (`valid_until_ts`) is only honored - /// in room version > 4 - /// - /// Remove the outdated keys and insert the new ones. - /// - /// This doesn't actually check that the keys provided are newer than the - /// old set. - pub fn add_signing_key( - &self, origin: &ServerName, new_keys: ServerSigningKeys, - ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { - // Not atomic, but this is not critical - let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; - - let mut keys = signingkeys - .and_then(|keys| serde_json::from_slice(&keys).ok()) - .unwrap_or_else(|| { - // Just insert "now", it doesn't matter - ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) - }); - - let ServerSigningKeys { - verify_keys, - old_verify_keys, - .. - } = new_keys; - - keys.verify_keys.extend(verify_keys); - keys.old_verify_keys.extend(old_verify_keys); - - self.server_signingkeys.insert( - origin.as_bytes(), - &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), - )?; - - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - - Ok(tree) - } - - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found - /// for the server. 
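The watch() implementation deleted in the hunk above is a select-any over dozens of prefix watchers: push every wakeup source into a FuturesUnordered and return as soon as any one fires. The core pattern in isolation, runnable with tokio and futures (the sleep durations are placeholders for the real watch_prefix futures):

```rust
use std::time::Duration;

use futures::{stream::FuturesUnordered, StreamExt};

#[tokio::main]
async fn main() {
	let mut watchers = FuturesUnordered::new();
	for ms in [300_u64, 100, 200] {
		watchers.push(async move {
			tokio::time::sleep(Duration::from_millis(ms)).await;
			ms
		});
	}

	// next().await resolves when the *first* pushed future completes,
	// which is how watch() returned on the earliest database change.
	assert_eq!(watchers.next().await, Some(100));
}
```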
- pub fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { - let signingkeys = self - .signing_keys_for(origin)? - .map_or_else(BTreeMap::new, |keys: ServerSigningKeys| { - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - tree - }); - - Ok(signingkeys) - } - - pub fn signing_keys_for(&self, origin: &ServerName) -> Result<Option<ServerSigningKeys>> { - let signingkeys = self - .server_signingkeys - .get(origin.as_bytes())? - .and_then(|bytes| serde_json::from_slice(&bytes).ok()); - - Ok(signingkeys) - } - - pub fn database_version(&self) -> Result<u64> { - self.global.get(b"version")?.map_or(Ok(0), |version| { - utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid.")) - }) + pub async fn database_version(&self) -> u64 { + self.global + .get(b"version") + .await + .deserialized() + .unwrap_or(0) } #[inline] pub fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.global.insert(b"version", &new_version.to_be_bytes())?; + self.global.raw_put(b"version", new_version); Ok(()) } #[inline] - pub fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { self.db.db.backup() } + pub fn backup(&self) -> Result { self.db.db.backup() } #[inline] pub fn backup_list(&self) -> Result<String> { self.db.db.backup_list() } diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs deleted file mode 100644 index 66917520b8f17a7de36db08404ae438b7b676727..0000000000000000000000000000000000000000 --- a/src/service/globals/migrations.rs +++ /dev/null @@ -1,864 +0,0 @@ -use std::{ - collections::{HashMap, HashSet}, - fs::{self}, - io::Write, - mem::size_of, - sync::Arc, -}; - -use conduit::{debug, debug_info, debug_warn, error, info, utils, warn, Error, Result}; -use itertools::Itertools; -use ruma::{ - events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType}, - push::Ruleset, - EventId, OwnedRoomId, RoomId, UserId, -}; - -use crate::{media, Services}; - -/// The current schema version. -/// - If database is opened at greater version we reject with error. The -/// software must be updated for backward-incompatible changes. -/// - If database is opened at lesser version we apply migrations up to this. -/// Note that named-feature migrations may also be performed when opening at -/// equal or lesser version. These are expected to be backward-compatible. -pub(crate) const DATABASE_VERSION: u64 = 13; - -/// Conduit's database version. -/// -/// Conduit bumped the database version to 16, but did not introduce any -/// breaking changes. Their database migrations are extremely fragile and risky, -/// and also do not really apply to us, so just to retain Conduit -> conduwuit -/// compatibility we'll check for both versions. -pub(crate) const CONDUIT_DATABASE_VERSION: u64 = 16; - -pub(crate) async fn migrations(services: &Services) -> Result<()> { - // Matrix resource ownership is based on the server name; changing it - // requires recreating the database from scratch. - if services.users.count()? > 0 { - let conduit_user = &services.globals.server_user; - - if !services.users.exists(conduit_user)? 
{ - error!("The {} server user does not exist, and the database is not new.", conduit_user); - return Err(Error::bad_database( - "Cannot reuse an existing database after changing the server name, please delete the old one first.", - )); - } - } - - if services.users.count()? > 0 { - migrate(services).await - } else { - fresh(services).await - } -} - -async fn fresh(services: &Services) -> Result<()> { - let db = &services.db; - let config = &services.server.config; - - services - .globals - .db - .bump_database_version(DATABASE_VERSION)?; - - db["global"].insert(b"feat_sha256_media", &[])?; - db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?; - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; - - // Create the admin room and server user on first run - crate::admin::create_admin_room(services).await?; - - warn!( - "Created new {} database with version {DATABASE_VERSION}", - config.database_backend, - ); - - Ok(()) -} - -/// Apply any migrations -async fn migrate(services: &Services) -> Result<()> { - let db = &services.db; - let config = &services.server.config; - - if services.globals.db.database_version()? < 1 { - db_lt_1(services).await?; - } - - if services.globals.db.database_version()? < 2 { - db_lt_2(services).await?; - } - - if services.globals.db.database_version()? < 3 { - db_lt_3(services).await?; - } - - if services.globals.db.database_version()? < 4 { - db_lt_4(services).await?; - } - - if services.globals.db.database_version()? < 5 { - db_lt_5(services).await?; - } - - if services.globals.db.database_version()? < 6 { - db_lt_6(services).await?; - } - - if services.globals.db.database_version()? < 7 { - db_lt_7(services).await?; - } - - if services.globals.db.database_version()? < 8 { - db_lt_8(services).await?; - } - - if services.globals.db.database_version()? < 9 { - db_lt_9(services).await?; - } - - if services.globals.db.database_version()? < 10 { - db_lt_10(services).await?; - } - - if services.globals.db.database_version()? < 11 { - db_lt_11(services).await?; - } - - if services.globals.db.database_version()? < 12 { - db_lt_12(services).await?; - } - - // This migration can be reused as-is anytime the server-default rules are - // updated. - if services.globals.db.database_version()? < 13 { - db_lt_13(services).await?; - } - - if db["global"].get(b"feat_sha256_media")?.is_none() { - media::migrations::migrate_sha256_media(services).await?; - } else if config.media_startup_check { - media::migrations::checkup_sha256_media(services).await?; - } - - if db["global"] - .get(b"fix_bad_double_separator_in_state_cache")? - .is_none() - { - fix_bad_double_separator_in_state_cache(services).await?; - } - - if db["global"] - .get(b"retroactively_fix_bad_data_from_roomuserid_joined")? 
- .is_none() - { - retroactively_fix_bad_data_from_roomuserid_joined(services).await?; - } - - let version_match = services.globals.db.database_version().unwrap() == DATABASE_VERSION - || services.globals.db.database_version().unwrap() == CONDUIT_DATABASE_VERSION; - - assert!( - version_match, - "Failed asserting local database version {} is equal to known latest conduwuit database version {}", - services.globals.db.database_version().unwrap(), - DATABASE_VERSION, - ); - - { - let patterns = services.globals.forbidden_usernames(); - if !patterns.is_empty() { - for user_id in services - .users - .iter() - .filter_map(Result::ok) - .filter(|user| !services.users.is_deactivated(user).unwrap_or(true)) - .filter(|user| user.server_name() == config.server_name) - { - let matches = patterns.matches(user_id.localpart()); - if matches.matched_any() { - warn!( - "User {} matches the following forbidden username patterns: {}", - user_id.to_string(), - matches - .into_iter() - .map(|x| &patterns.patterns()[x]) - .join(", ") - ); - } - } - } - } - - { - let patterns = services.globals.forbidden_alias_names(); - if !patterns.is_empty() { - for address in services.rooms.metadata.iter_ids() { - let room_id = address?; - let room_aliases = services.rooms.alias.local_aliases_for_room(&room_id); - for room_alias_result in room_aliases { - let room_alias = room_alias_result?; - let matches = patterns.matches(room_alias.alias()); - if matches.matched_any() { - warn!( - "Room with alias {} ({}) matches the following forbidden room name patterns: {}", - room_alias, - &room_id, - matches - .into_iter() - .map(|x| &patterns.patterns()[x]) - .join(", ") - ); - } - } - } - } - } - - info!( - "Loaded {} database with schema version {DATABASE_VERSION}", - config.database_backend, - ); - - Ok(()) -} - -async fn db_lt_1(services: &Services) -> Result<()> { - let db = &services.db; - - let roomserverids = &db["roomserverids"]; - let serverroomids = &db["serverroomids"]; - for (roomserverid, _) in roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xFF); - let room_id = parts.next().expect("split always returns one element"); - let Some(servername) = parts.next() else { - error!("Migration: Invalid roomserverid in db."); - continue; - }; - let mut serverroomid = servername.to_vec(); - serverroomid.push(0xFF); - serverroomid.extend_from_slice(room_id); - - serverroomids.insert(&serverroomid, &[])?; - } - - services.globals.db.bump_database_version(1)?; - info!("Migration: 0 -> 1 finished"); - Ok(()) -} - -async fn db_lt_2(services: &Services) -> Result<()> { - let db = &services.db; - - // We accidentally inserted hashed versions of "" into the db instead of just "" - let userid_password = &db["roomserverids"]; - for (userid, password) in userid_password.iter() { - let empty_pass = utils::hash::password("").expect("our own password to be properly hashed"); - let password = std::str::from_utf8(&password).expect("password is valid utf-8"); - let empty_hashed_password = utils::hash::verify_password(password, &empty_pass).is_ok(); - if empty_hashed_password { - userid_password.insert(&userid, b"")?; - } - } - - services.globals.db.bump_database_version(2)?; - info!("Migration: 1 -> 2 finished"); - Ok(()) -} - -async fn db_lt_3(services: &Services) -> Result<()> { - let db = &services.db; - - // Move media to filesystem - let mediaid_file = &db["mediaid_file"]; - for (key, content) in mediaid_file.iter() { - if content.is_empty() { - continue; - } - - #[allow(deprecated)] - let path = 
services.media.get_media_file(&key); - let mut file = fs::File::create(path)?; - file.write_all(&content)?; - mediaid_file.insert(&key, &[])?; - } - - services.globals.db.bump_database_version(3)?; - info!("Migration: 2 -> 3 finished"); - Ok(()) -} - -async fn db_lt_4(services: &Services) -> Result<()> { - let config = &services.server.config; - - // Add federated users to services as deactivated - for our_user in services.users.iter() { - let our_user = our_user?; - if services.users.is_deactivated(&our_user)? { - continue; - } - for room in services.rooms.state_cache.rooms_joined(&our_user) { - for user in services.rooms.state_cache.room_members(&room?) { - let user = user?; - if user.server_name() != config.server_name { - info!(?user, "Migration: creating user"); - services.users.create(&user, None)?; - } - } - } - } - - services.globals.db.bump_database_version(4)?; - info!("Migration: 3 -> 4 finished"); - Ok(()) -} - -async fn db_lt_5(services: &Services) -> Result<()> { - let db = &services.db; - - // Upgrade user data store - let roomuserdataid_accountdata = &db["roomuserdataid_accountdata"]; - let roomusertype_roomuserdataid = &db["roomusertype_roomuserdataid"]; - for (roomuserdataid, _) in roomuserdataid_accountdata.iter() { - let mut parts = roomuserdataid.split(|&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let user_id = parts.next().unwrap(); - let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap(); - - let mut key = room_id.to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id); - key.push(0xFF); - key.extend_from_slice(event_type); - - roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; - } - - services.globals.db.bump_database_version(5)?; - info!("Migration: 4 -> 5 finished"); - Ok(()) -} - -async fn db_lt_6(services: &Services) -> Result<()> { - let db = &services.db; - - // Set room member count - let roomid_shortstatehash = &db["roomid_shortstatehash"]; - for (roomid, _) in roomid_shortstatehash.iter() { - let string = utils::string_from_bytes(&roomid).unwrap(); - let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - services.rooms.state_cache.update_joined_count(room_id)?; - } - - services.globals.db.bump_database_version(6)?; - info!("Migration: 5 -> 6 finished"); - Ok(()) -} - -async fn db_lt_7(services: &Services) -> Result<()> { - let db = &services.db; - - // Upgrade state store - let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new(); - let mut current_sstatehash: Option<u64> = None; - let mut current_room = None; - let mut current_state = HashSet::new(); - - let handle_state = |current_sstatehash: u64, - current_room: &RoomId, - current_state: HashSet<_>, - last_roomstates: &mut HashMap<_, _>| { - let last_roomsstatehash = last_roomstates.get(current_room); - - let states_parents = last_roomsstatehash.map_or_else( - || Ok(Vec::new()), - |&last_roomsstatehash| { - services - .rooms - .state_compressor - .load_shortstatehash_info(last_roomsstatehash) - }, - )?; - - let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew = current_state - .difference(&parent_stateinfo.1) - .copied() - .collect::<HashSet<_>>(); - - let statediffremoved = parent_stateinfo - .1 - .difference(&current_state) - .copied() - .collect::<HashSet<_>>(); - - (statediffnew, statediffremoved) - } else { - (current_state, HashSet::new()) - }; - - services.rooms.state_compressor.save_state_from_diff( - current_sstatehash, - Arc::new(statediffnew), - Arc::new(statediffremoved), - 2,
// every state change is 2 event changes on average - states_parents, - )?; - - /* - let mut tmp = services.rooms.load_shortstatehash_info(&current_sstatehash)?; - let state = tmp.pop().unwrap(); - println!( - "{}\t{}{:?}: {:?} + {:?} - {:?}", - current_room, - " ".repeat(tmp.len()), - utils::u64_from_bytes(&current_sstatehash).unwrap(), - tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), - state - .2 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) - .collect::<Vec<_>>(), - state - .3 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) - .collect::<Vec<_>>() - ); - */ - - Ok::<_, Error>(()) - }; - - let stateid_shorteventid = &db["stateid_shorteventid"]; - let shorteventid_eventid = &db["shorteventid_eventid"]; - for (k, seventid) in stateid_shorteventid.iter() { - let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()]).expect("number of bytes is correct"); - let sstatekey = k[size_of::<u64>()..].to_vec(); - if Some(sstatehash) != current_sstatehash { - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash); - } - current_state = HashSet::new(); - current_sstatehash = Some(sstatehash); - - let event_id = shorteventid_eventid.get(&seventid).unwrap().unwrap(); - let string = utils::string_from_bytes(&event_id).unwrap(); - let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = services.rooms.timeline.get_pdu(event_id).unwrap().unwrap(); - - if Some(&pdu.room_id) != current_room.as_ref() { - current_room = Some(pdu.room_id.clone()); - } - } - - let mut val = sstatekey; - val.extend_from_slice(&seventid); - current_state.insert(val.try_into().expect("size is correct")); - } - - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - } - - services.globals.db.bump_database_version(7)?; - info!("Migration: 6 -> 7 finished"); - Ok(()) -} - -async fn db_lt_8(services: &Services) -> Result<()> { - let db = &services.db; - - let roomid_shortstatehash = &db["roomid_shortstatehash"]; - let roomid_shortroomid = &db["roomid_shortroomid"]; - let pduid_pdu = &db["pduid_pdu"]; - let eventid_pduid = &db["eventid_pduid"]; - - // Generate short room ids for all rooms - for (room_id, _) in roomid_shortstatehash.iter() { - let shortroomid = services.globals.next_count()?.to_be_bytes(); - roomid_shortroomid.insert(&room_id, &shortroomid)?; - info!("Migration: 8"); - } - // Update pduids db layout - let batch = pduid_pdu - .iter() - .filter_map(|(key, v)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(2, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_key = short_room_id.to_vec(); - new_key.extend_from_slice(count); - - Some(database::OwnedKeyVal(new_key, v)) - }) - .collect::<Vec<_>>(); - - pduid_pdu.insert_batch(batch.iter().map(database::KeyVal::from))?; - - let batch2 = eventid_pduid - .iter() - .filter_map(|(k, value)| { - if !value.starts_with(b"!") { - return None; - } - let mut parts = value.splitn(2, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id =
roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_value = short_room_id.to_vec(); - new_value.extend_from_slice(count); - - Some(database::OwnedKeyVal(k, new_value)) - }) - .collect::<Vec<_>>(); - - eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?; - - services.globals.db.bump_database_version(8)?; - info!("Migration: 7 -> 8 finished"); - Ok(()) -} - -async fn db_lt_9(services: &Services) -> Result<()> { - let db = &services.db; - - let tokenids = &db["tokenids"]; - let roomid_shortroomid = &db["roomid_shortroomid"]; - - // Update tokenids db layout - let mut iter = tokenids - .iter() - .filter_map(|(key, _)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(4, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let word = parts.next().unwrap(); - let _pdu_id_room = parts.next().unwrap(); - let pdu_id_count = parts.next().unwrap(); - - let short_room_id = roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - let mut new_key = short_room_id.to_vec(); - new_key.extend_from_slice(word); - new_key.push(0xFF); - new_key.extend_from_slice(pdu_id_count); - Some(database::OwnedKeyVal(new_key, Vec::<u8>::new())) - }) - .peekable(); - - while iter.peek().is_some() { - let batch = iter.by_ref().take(1000).collect::<Vec<_>>(); - tokenids.insert_batch(batch.iter().map(database::KeyVal::from))?; - debug!("Inserted smaller batch"); - } - - info!("Deleting starts"); - - let batch2: Vec<_> = tokenids - .iter() - .filter_map(|(key, _)| { - if key.starts_with(b"!") { - Some(key) - } else { - None - } - }) - .collect(); - - for key in batch2 { - tokenids.remove(&key)?; - } - - services.globals.db.bump_database_version(9)?; - info!("Migration: 8 -> 9 finished"); - Ok(()) -} - -async fn db_lt_10(services: &Services) -> Result<()> { - let db = &services.db; - - let statekey_shortstatekey = &db["statekey_shortstatekey"]; - let shortstatekey_statekey = &db["shortstatekey_statekey"]; - - // Add other direction for shortstatekeys - for (statekey, shortstatekey) in statekey_shortstatekey.iter() { - shortstatekey_statekey.insert(&shortstatekey, &statekey)?; - } - - // Force E2EE device list updates so we can send them over federation - for user_id in services.users.iter().filter_map(Result::ok) { - services.users.mark_device_key_update(&user_id)?; - } - - services.globals.db.bump_database_version(10)?; - info!("Migration: 9 -> 10 finished"); - Ok(()) -} - -#[allow(unreachable_code)] -async fn db_lt_11(services: &Services) -> Result<()> { - error!("Dropping a column to clear data is not implemented yet."); - //let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"]; - //userdevicesessionid_uiaarequest.clear()?; - - services.globals.db.bump_database_version(11)?; - info!("Migration: 10 -> 11 finished"); - Ok(()) -} - -async fn db_lt_12(services: &Services) -> Result<()> { - let config = &services.server.config; - - for username in services.users.list_local_users()? 
{ - let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { - Ok(u) => u, - Err(e) => { - warn!("Invalid username {username}: {e}"); - continue; - }, - }; - - let raw_rules_list = services - .account_data - .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap() - .expect("Username is invalid"); - - let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap(); - let rules_list = &mut account_data.content.global; - - //content rule - { - let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"]; - - let rule = rules_list.content.get(content_rule_transformation[0]); - if rule.is_some() { - let mut rule = rule.unwrap().clone(); - content_rule_transformation[1].clone_into(&mut rule.rule_id); - rules_list - .content - .shift_remove(content_rule_transformation[0]); - rules_list.content.insert(rule); - } - } - - //underride rules - { - let underride_rule_transformation = [ - [".m.rules.call", ".m.rule.call"], - [".m.rules.room_one_to_one", ".m.rule.room_one_to_one"], - [".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one"], - [".m.rules.message", ".m.rule.message"], - [".m.rules.encrypted", ".m.rule.encrypted"], - ]; - - for transformation in underride_rule_transformation { - let rule = rules_list.underride.get(transformation[0]); - if let Some(rule) = rule { - let mut rule = rule.clone(); - transformation[1].clone_into(&mut rule.rule_id); - rules_list.underride.shift_remove(transformation[0]); - rules_list.underride.insert(rule); - } - } - } - - services.account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; - } - - services.globals.db.bump_database_version(12)?; - info!("Migration: 11 -> 12 finished"); - Ok(()) -} - -async fn db_lt_13(services: &Services) -> Result<()> { - let config = &services.server.config; - - for username in services.users.list_local_users()? 
{ - let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { - Ok(u) => u, - Err(e) => { - warn!("Invalid username {username}: {e}"); - continue; - }, - }; - - let raw_rules_list = services - .account_data - .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap() - .expect("Username is invalid"); - - let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap(); - - let user_default_rules = Ruleset::server_default(&user); - account_data - .content - .global - .update_with_server_default(user_default_rules); - - services.account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; - } - - services.globals.db.bump_database_version(13)?; - info!("Migration: 12 -> 13 finished"); - Ok(()) -} - -async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result<()> { - warn!("Fixing bad double separator in state_cache roomuserid_joined"); - - let db = &services.db; - let roomuserid_joined = &db["roomuserid_joined"]; - let _cork = db.cork_and_sync(); - - let mut iter_count: usize = 0; - for (mut key, value) in roomuserid_joined.iter() { - iter_count = iter_count.saturating_add(1); - debug_info!(%iter_count); - let first_sep_index = key - .iter() - .position(|&i| i == 0xFF) - .expect("found 0xFF delim"); - - if key - .iter() - .get(first_sep_index..=first_sep_index.saturating_add(1)) - .copied() - .collect_vec() - == vec![0xFF, 0xFF] - { - debug_warn!("Found bad key: {key:?}"); - roomuserid_joined.remove(&key)?; - - key.remove(first_sep_index); - debug_warn!("Fixed key: {key:?}"); - roomuserid_joined.insert(&key, &value)?; - } - } - - db.db.cleanup()?; - db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?; - - info!("Finished fixing"); - Ok(()) -} - -async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) -> Result<()> { - warn!("Retroactively fixing bad data from broken roomuserid_joined"); - - let db = &services.db; - let _cork = db.cork_and_sync(); - - let room_ids = services - .rooms - .metadata - .iter_ids() - .filter_map(Result::ok) - .collect_vec(); - - for room_id in room_ids.clone() { - debug_info!("Fixing room {room_id}"); - - let users_in_room = services - .rooms - .state_cache - .room_members(&room_id) - .filter_map(Result::ok) - .collect_vec(); - - let joined_members = users_in_room - .iter() - .filter(|user_id| { - services - .rooms - .state_accessor - .get_member(&room_id, user_id) - .unwrap_or(None) - .map_or(false, |membership| membership.membership == MembershipState::Join) - }) - .collect_vec(); - - let non_joined_members = users_in_room - .iter() - .filter(|user_id| { - services - .rooms - .state_accessor - .get_member(&room_id, user_id) - .unwrap_or(None) - .map_or(false, |membership| { - membership.membership == MembershipState::Leave || membership.membership == MembershipState::Ban - }) - }) - .collect_vec(); - - for user_id in joined_members { - debug_info!("User is joined, marking as joined"); - services - .rooms - .state_cache - .mark_as_joined(user_id, &room_id)?; - } - - for user_id in non_joined_members { - debug_info!("User is left or banned, marking as left"); - services.rooms.state_cache.mark_as_left(user_id, &room_id)?; - } - } - - for room_id in room_ids { - debug_info!( - "Updating joined count for room {room_id} to fix servers in room after correcting membership states" - ); - - 
services.rooms.state_cache.update_joined_count(&room_id)?; - } - - db.db.cleanup()?; - db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; - - info!("Finished fixing"); - Ok(()) -} diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 87f8f4925829a664ba5f919ef1e82586935f08a2..55dd10aabcd82d5f93e0a60567bab827d1d3754c 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,8 +1,7 @@ mod data; -pub(super) mod migrations; use std::{ - collections::{BTreeMap, HashMap}, + collections::HashMap, fmt::Write, sync::{Arc, RwLock}, time::Instant, @@ -13,13 +12,8 @@ use ipaddress::IPAddress; use regex::RegexSet; use ruma::{ - api::{ - client::discovery::discover_support::ContactRole, - federation::discovery::{ServerSigningKeys, VerifyKey}, - }, - serde::Base64, - DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomAliasId, - RoomVersionId, ServerName, UserId, + api::client::discovery::discover_support::ContactRole, OwnedEventId, OwnedRoomAliasId, OwnedServerName, + OwnedUserId, RoomAliasId, RoomVersionId, ServerName, UserId, }; use tokio::sync::Mutex; use url::Url; @@ -31,7 +25,6 @@ pub struct Service { pub config: Config, pub cidr_range_denylist: Vec<IPAddress>, - keypair: Arc<ruma::signatures::Ed25519KeyPair>, jwt_decoding_key: Option<jsonwebtoken::DecodingKey>, pub stable_room_versions: Vec<RoomVersionId>, pub unstable_room_versions: Vec<RoomVersionId>, @@ -41,6 +34,7 @@ pub struct Service { pub server_user: OwnedUserId, pub admin_alias: OwnedRoomAliasId, pub turn_secret: String, + pub registration_token: Option<String>, } type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries @@ -49,16 +43,6 @@ impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let db = Data::new(&args); let config = &args.server.config; - let keypair = db.load_keypair(); - - let keypair = match keypair { - Ok(k) => k, - Err(e) => { - error!("Keypair invalid. Deleting..."); - db.remove_keypair()?; - return Err(e); - }, - }; let jwt_decoding_key = config .jwt_secret @@ -96,11 +80,24 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { }) }); + let registration_token = + config + .registration_token_file + .as_ref() + .map_or(config.registration_token.clone(), |path| { + let Ok(token) = std::fs::read_to_string(path).inspect_err(|e| { + error!("Failed to read the registration token file: {e}"); + }) else { + return config.registration_token.clone(); + }; + + Some(token) + }); + let mut s = Self { db, config: config.clone(), cidr_range_denylist, - keypair: Arc::new(keypair), jwt_decoding_key, stable_room_versions, unstable_room_versions, @@ -112,6 +109,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { server_user: UserId::parse_with_server_name(String::from("conduit"), &config.server_name) .expect("@conduit:server_name is valid"), turn_secret, + registration_token, }; if !s @@ -159,24 +157,15 @@ fn name(&self) -> &str { service::make_name(std::module_path!()) } } impl Service { - /// Returns this server's keypair. 
- pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair } - #[inline] pub fn next_count(&self) -> Result<u64> { self.db.next_count() } #[inline] pub fn current_count(&self) -> Result<u64> { Ok(self.db.current_count()) } - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - self.db.watch(user_id, device_id).await - } - #[inline] pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } - pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events } - pub fn allow_registration(&self) -> bool { self.config.allow_registration } pub fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration } @@ -208,8 +197,6 @@ pub fn allow_check_for_updates(&self) -> bool { self.config.allow_check_for_upda pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers } - pub fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first } - pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } pub fn turn_password(&self) -> &String { &self.config.turn_password } @@ -220,10 +207,6 @@ pub fn turn_uris(&self) -> &[String] { &self.config.turn_uris } pub fn turn_username(&self) -> &String { &self.config.turn_username } - pub fn allow_profile_lookup_federation_requests(&self) -> bool { - self.config.allow_profile_lookup_federation_requests - } - pub fn notification_push_path(&self) -> &String { &self.config.notification_push_path } pub fn emergency_password(&self) -> &Option<String> { &self.config.emergency_password } @@ -260,10 +243,6 @@ pub fn allow_incoming_read_receipts(&self) -> bool { self.config.allow_incoming_ pub fn allow_outgoing_read_receipts(&self) -> bool { self.config.allow_outgoing_read_receipts } - pub fn forbidden_remote_room_directory_server_names(&self) -> &[OwnedServerName] { - &self.config.forbidden_remote_room_directory_server_names - } - pub fn well_known_support_page(&self) -> &Option<Url> { &self.config.well_known.support_page } pub fn well_known_support_role(&self) -> &Option<ContactRole> { &self.config.well_known.support_role } @@ -286,28 +265,6 @@ pub fn supported_room_versions(&self) -> Vec<RoomVersionId> { } } - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found - /// for the server. 
- pub fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { - let mut keys = self.db.verify_keys_for(origin)?; - if origin == self.server_name() { - keys.insert( - format!("ed25519:{}", self.keypair().version()) - .try_into() - .expect("found invalid server signing keys in DB"), - VerifyKey { - key: Base64::new(self.keypair.public_key().to_vec()), - }, - ); - } - - Ok(keys) - } - - pub fn signing_keys_for(&self, origin: &ServerName) -> Result<Option<ServerSigningKeys>> { - self.db.signing_keys_for(origin) - } - pub fn well_known_client(&self) -> &Option<Url> { &self.config.well_known.client } pub fn well_known_server(&self) -> &Option<OwnedServerName> { &self.config.well_known.server } @@ -329,4 +286,7 @@ pub fn user_is_local(&self, user_id: &UserId) -> bool { self.server_is_ours(user #[inline] pub fn server_is_ours(&self, server_name: &ServerName) -> bool { server_name == self.config.server_name } + + #[inline] + pub fn is_read_only(&self) -> bool { self.db.db.is_read_only() } } diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs deleted file mode 100644 index 30ac593b1e48777898c4ed53c525b45ab8c8145d..0000000000000000000000000000000000000000 --- a/src/service/key_backups/data.rs +++ /dev/null @@ -1,346 +0,0 @@ -use std::{collections::BTreeMap, sync::Arc}; - -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{ - api::client::{ - backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, - error::ErrorKind, - }, - serde::Raw, - OwnedRoomId, RoomId, UserId, -}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - backupid_algorithm: Arc<Map>, - backupid_etag: Arc<Map>, - backupkeyid_backup: Arc<Map>, - services: Services, -} - -struct Services { - globals: Dep<globals::Service>, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - backupid_algorithm: db["backupid_algorithm"].clone(), - backupid_etag: db["backupid_etag"].clone(), - backupkeyid_backup: db["backupkeyid_backup"].clone(), - services: Services { - globals: args.depend::<globals::Service>("globals"), - }, - } - } - - pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> { - let version = self.services.globals.next_count()?.to_string(); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm.insert( - &key, - &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), - )?; - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - Ok(version) - } - - pub(super) fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm.remove(&key)?; - self.backupid_etag.remove(&key)?; - - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn update_backup( - &self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>, - ) -> Result<String> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); - } - - 
self.backupid_algorithm - .insert(&key, backup_metadata.json().get().as_bytes())?; - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - Ok(version.to_owned()) - } - - pub(super) fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, _)| { - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) - }) - .transpose() - } - - pub(super) fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, value)| { - let version = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; - - Ok(( - version, - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?, - )) - }) - .transpose() - } - - pub(super) fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - self.backupid_algorithm - .get(&key)? - .map_or(Ok(None), |bytes| { - serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) - }) - } - - pub(super) fn add_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); - } - - self.backupid_etag - .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; - - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - self.backupkeyid_backup - .insert(&key, key_data.json().get().as_bytes())?; - - Ok(()) - } - - pub(super) fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - - Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) - } - - pub(super) fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - - Ok(utils::u64_from_bytes( - &self - .backupid_etag - .get(&key)? - .ok_or_else(|| Error::bad_database("Backup has no etag."))?, - ) - .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? 
- .to_string()) - } - - pub(super) fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xFF); - - let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new(); - - for result in self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); - - let session_id = utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let room_id = RoomId::parse( - utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; - - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - - Ok::<_, Error>((room_id, session_id, key_data)) - }) { - let (room_id, session_id, key_data) = result?; - rooms - .entry(room_id) - .or_insert_with(|| RoomKeyBackup { - sessions: BTreeMap::new(), - }) - .sessions - .insert(session_id, key_data); - } - - Ok(rooms) - } - - pub(super) fn get_room( - &self, user_id: &UserId, version: &str, room_id: &RoomId, - ) -> Result<BTreeMap<String, Raw<KeyBackupData>>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - Ok(self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); - - let session_id = utils::string_from_bytes( - parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - - Ok::<_, Error>((session_id, key_data)) - }) - .filter_map(Result::ok) - .collect()) - } - - pub(super) fn get_session( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result<Option<Raw<KeyBackupData>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - self.backupkeyid_backup - .get(&key)? 
- .map(|value| { - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")) - }) - .transpose() - } - - pub(super) fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } - - pub(super) fn delete_room_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(version.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(session_id.as_bytes()); - - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } - - Ok(()) - } -} diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 65d3c065ed7f44ea4cf4c3a80d39c79af26a4b0e..bae6f214441f95d5dcf00981591c41e9cea91d6c 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -1,93 +1,261 @@ -mod data; - use std::{collections::BTreeMap, sync::Arc}; -use conduit::Result; -use data::Data; +use conduit::{ + err, implement, + utils::stream::{ReadyExt, TryIgnore}, + Err, Result, +}; +use database::{Deserialized, Ignore, Interfix, Json, Map}; +use futures::StreamExt; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, OwnedRoomId, RoomId, UserId, }; +use crate::{globals, Dep}; + pub struct Service { db: Data, + services: Services, +} + +struct Data { + backupid_algorithm: Arc<Map>, + backupid_etag: Arc<Map>, + backupkeyid_backup: Arc<Map>, +} + +struct Services { + globals: Dep<globals::Service>, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + backupid_algorithm: args.db["backupid_algorithm"].clone(), + backupid_etag: args.db["backupid_etag"].clone(), + backupkeyid_backup: args.db["backupkeyid_backup"].clone(), + }, + services: Services { + globals: args.depend::<globals::Service>("globals"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> { - self.db.create_backup(user_id, backup_metadata) - } +#[implement(Service)] +pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> { + let version = self.services.globals.next_count()?.to_string(); + let count = self.services.globals.next_count()?; - pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - self.db.delete_backup(user_id, version) - } + let key = (user_id, &version); + self.db.backupid_algorithm.put(key, Json(backup_metadata)); - pub fn update_backup( - 
&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>, - ) -> Result<String> { - self.db.update_backup(user_id, version, backup_metadata) - } + self.db.backupid_etag.put(key, count); - pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> { - self.db.get_latest_backup_version(user_id) - } + Ok(version) +} - pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> { - self.db.get_latest_backup(user_id) - } +#[implement(Service)] +pub async fn delete_backup(&self, user_id: &UserId, version: &str) { + let key = (user_id, version); + self.db.backupid_algorithm.del(key); + self.db.backupid_etag.del(key); - pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> { - self.db.get_backup(user_id, version) - } + let key = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .keys_prefix_raw(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} - pub fn add_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>, - ) -> Result<()> { - self.db - .add_key(user_id, version, room_id, session_id, key_data) +#[implement(Service)] +pub async fn update_backup<'a>( + &self, user_id: &UserId, version: &'a str, backup_metadata: &Raw<BackupAlgorithm>, +) -> Result<&'a str> { + let key = (user_id, version); + if self.db.backupid_algorithm.qry(&key).await.is_err() { + return Err!(Request(NotFound("Tried to update nonexistent backup."))); } - pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { self.db.count_keys(user_id, version) } + let count = self.services.globals.next_count().unwrap(); + self.db.backupid_etag.put(key, count); + self.db + .backupid_algorithm + .put_raw(key, backup_metadata.json().get()); - pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { self.db.get_etag(user_id, version) } + Ok(version) +} - pub fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> { - self.db.get_all(user_id, version) - } +#[implement(Service)] +pub async fn get_latest_backup_version(&self, user_id: &UserId) -> Result<String> { + type Key<'a> = (&'a UserId, &'a str); - pub fn get_room( - &self, user_id: &UserId, version: &str, room_id: &RoomId, - ) -> Result<BTreeMap<String, Raw<KeyBackupData>>> { - self.db.get_room(user_id, version, room_id) - } + let last_possible_key = (user_id, u64::MAX); + self.db + .backupid_algorithm + .rev_keys_from(&last_possible_key) + .ignore_err() + .ready_take_while(|(user_id_, _): &Key<'_>| *user_id_ == user_id) + .map(|(_, version): Key<'_>| version.to_owned()) + .next() + .await + .ok_or_else(|| err!(Request(NotFound("No backup versions found")))) +} - pub fn get_session( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result<Option<Raw<KeyBackupData>>> { - self.db.get_session(user_id, version, room_id, session_id) - } +#[implement(Service)] +pub async fn get_latest_backup(&self, user_id: &UserId) -> Result<(String, Raw<BackupAlgorithm>)> { + type Key<'a> = (&'a UserId, &'a str); + type KeyVal<'a> = (Key<'a>, Raw<BackupAlgorithm>); - pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - self.db.delete_all_keys(user_id, version) - } + let last_possible_key = (user_id, u64::MAX); + self.db + .backupid_algorithm + .rev_stream_from(&last_possible_key) + .ignore_err() + 
.ready_take_while(|((user_id_, _), _): &KeyVal<'_>| *user_id_ == user_id) + .map(|((_, version), algorithm): KeyVal<'_>| (version.to_owned(), algorithm)) + .next() + .await + .ok_or_else(|| err!(Request(NotFound("No backup found")))) +} - pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { - self.db.delete_room_keys(user_id, version, room_id) - } +#[implement(Service)] +pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Raw<BackupAlgorithm>> { + let key = (user_id, version); + self.db.backupid_algorithm.qry(&key).await.deserialized() +} - pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { - self.db - .delete_room_key(user_id, version, room_id, session_id) +#[implement(Service)] +pub async fn add_key( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>, +) -> Result<()> { + let key = (user_id, version); + if self.db.backupid_algorithm.qry(&key).await.is_err() { + return Err!(Request(NotFound("Tried to update nonexistent backup."))); } + + let count = self.services.globals.next_count().unwrap(); + self.db.backupid_etag.put(key, count); + + let key = (user_id, version, room_id, session_id); + self.db + .backupkeyid_backup + .put_raw(key, key_data.json().get()); + + Ok(()) +} + +#[implement(Service)] +pub async fn count_keys(&self, user_id: &UserId, version: &str) -> usize { + let prefix = (user_id, version); + self.db + .backupkeyid_backup + .keys_prefix_raw(&prefix) + .count() + .await +} + +#[implement(Service)] +pub async fn get_etag(&self, user_id: &UserId, version: &str) -> String { + let key = (user_id, version); + self.db + .backupid_etag + .qry(&key) + .await + .deserialized::<u64>() + .as_ref() + .map(ToString::to_string) + .expect("Backup has no etag.") +} + +#[implement(Service)] +pub async fn get_all(&self, user_id: &UserId, version: &str) -> BTreeMap<OwnedRoomId, RoomKeyBackup> { + type Key<'a> = (Ignore, Ignore, &'a RoomId, &'a str); + type KeyVal<'a> = (Key<'a>, Raw<KeyBackupData>); + + let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new(); + let default = || RoomKeyBackup { + sessions: BTreeMap::new(), + }; + + let prefix = (user_id, version, Interfix); + self.db + .backupkeyid_backup + .stream_prefix(&prefix) + .ignore_err() + .ready_for_each(|((_, _, room_id, session_id), key_backup_data): KeyVal<'_>| { + rooms + .entry(room_id.into()) + .or_insert_with(default) + .sessions + .insert(session_id.into(), key_backup_data); + }) + .await; + + rooms +} + +#[implement(Service)] +pub async fn get_room( + &self, user_id: &UserId, version: &str, room_id: &RoomId, +) -> BTreeMap<String, Raw<KeyBackupData>> { + type KeyVal<'a> = ((Ignore, Ignore, Ignore, &'a str), Raw<KeyBackupData>); + + let prefix = (user_id, version, room_id, Interfix); + self.db + .backupkeyid_backup + .stream_prefix(&prefix) + .ignore_err() + .map(|((.., session_id), key_backup_data): KeyVal<'_>| (session_id.to_owned(), key_backup_data)) + .collect() + .await +} + +#[implement(Service)] +pub async fn get_session( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, +) -> Result<Raw<KeyBackupData>> { + let key = (user_id, version, room_id, session_id); + + self.db.backupkeyid_backup.qry(&key).await.deserialized() +} + +#[implement(Service)] +pub async fn delete_all_keys(&self, user_id: &UserId, version: &str) { + let key = (user_id, version, Interfix); + self.db + .backupkeyid_backup + 
.keys_prefix_raw(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) { + let key = (user_id, version, room_id, Interfix); + self.db + .backupkeyid_backup + .keys_prefix_raw(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; +} + +#[implement(Service)] +pub async fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) { + let key = (user_id, version, room_id, session_id); + self.db + .backupkeyid_backup + .keys_prefix_raw(&key) + .ignore_err() + .ready_for_each(|outdated_key| self.db.backupkeyid_backup.remove(outdated_key)) + .await; } diff --git a/src/service/manager.rs b/src/service/manager.rs index 42260bb30820ab6635529e82051a6c1289730bcb..21e0ed7c24ffadbb6b5549614cf913e74ad9ebec 100644 --- a/src/service/manager.rs +++ b/src/service/manager.rs @@ -1,7 +1,7 @@ use std::{panic::AssertUnwindSafe, sync::Arc, time::Duration}; use conduit::{debug, debug_warn, error, trace, utils::time, warn, Err, Error, Result, Server}; -use futures_util::FutureExt; +use futures::FutureExt; use tokio::{ sync::{Mutex, MutexGuard}, task::{JoinHandle, JoinSet}, diff --git a/src/service/media/data.rs b/src/service/media/data.rs index e5d6d20b19f0d004c39d381a8f560fdd7d650ea0..9afbd708fe1a55bae000afe136f043721a066304 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,12 +1,13 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use conduit::{ - debug, debug_info, trace, - utils::{str_from_bytes, string_from_bytes}, + debug, debug_info, err, + utils::{str_from_bytes, stream::TryIgnore, string_from_bytes, ReadyExt}, Err, Error, Result, }; -use database::{Database, Map}; -use ruma::{api::client::error::ErrorKind, http_headers::ContentDisposition, Mxc, OwnedMxcUri, UserId}; +use database::{Database, Interfix, Map}; +use futures::StreamExt; +use ruma::{http_headers::ContentDisposition, Mxc, OwnedMxcUri, UserId}; use super::{preview::UrlPreviewData, thumbnail::Dim}; @@ -36,88 +37,54 @@ pub(super) fn create_file_metadata( &self, mxc: &Mxc<'_>, user: Option<&UserId>, dim: &Dim, content_disposition: Option<&ContentDisposition>, content_type: Option<&str>, ) -> Result<Vec<u8>> { - let mut key: Vec<u8> = Vec::new(); - key.extend_from_slice(b"mxc://"); - key.extend_from_slice(mxc.server_name.as_bytes()); - key.extend_from_slice(b"/"); - key.extend_from_slice(mxc.media_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(&dim.width.to_be_bytes()); - key.extend_from_slice(&dim.height.to_be_bytes()); - key.push(0xFF); - key.extend_from_slice( - content_disposition - .map(ToString::to_string) - .unwrap_or_default() - .as_bytes(), - ); - key.push(0xFF); - key.extend_from_slice( - content_type - .as_ref() - .map(|c| c.as_bytes()) - .unwrap_or_default(), - ); - - self.mediaid_file.insert(&key, &[])?; - + let dim: &[u32] = &[dim.width, dim.height]; + let key = (mxc, dim, content_disposition, content_type); + let key = database::serialize_to_vec(key)?; + self.mediaid_file.insert(&key, []); if let Some(user) = user { - let mut key: Vec<u8> = Vec::new(); - key.extend_from_slice(b"mxc://"); - key.extend_from_slice(mxc.server_name.as_bytes()); - key.extend_from_slice(b"/"); - key.extend_from_slice(mxc.media_id.as_bytes()); - let user = user.as_bytes().to_vec(); - self.mediaid_user.insert(&key, &user)?; + let key = (mxc, 
user); + self.mediaid_user.put_raw(key, user); } Ok(key) } - pub(super) fn delete_file_mxc(&self, mxc: &Mxc<'_>) -> Result<()> { + pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { debug!("MXC URI: {mxc}"); - let mut prefix: Vec<u8> = Vec::new(); - prefix.extend_from_slice(b"mxc://"); - prefix.extend_from_slice(mxc.server_name.as_bytes()); - prefix.extend_from_slice(b"/"); - prefix.extend_from_slice(mxc.media_id.as_bytes()); - prefix.push(0xFF); - - trace!("MXC db prefix: {prefix:?}"); - for (key, _) in self.mediaid_file.scan_prefix(prefix.clone()) { - debug!("Deleting key: {:?}", key); - self.mediaid_file.remove(&key)?; - } + let prefix = (mxc, Interfix); + self.mediaid_file + .keys_prefix_raw(&prefix) + .ignore_err() + .ready_for_each(|key| self.mediaid_file.remove(key)) + .await; - for (key, value) in self.mediaid_user.scan_prefix(prefix.clone()) { - if key.starts_with(&prefix) { - let user = str_from_bytes(&value).unwrap_or_default(); + self.mediaid_user + .stream_prefix_raw(&prefix) + .ignore_err() + .ready_for_each(|(key, val)| { + debug_assert!(key.starts_with(mxc.to_string().as_bytes()), "key should start with the mxc"); - debug_info!("Deleting key \"{key:?}\" which was uploaded by user {user}"); - self.mediaid_user.remove(&key)?; - } - } + let user = str_from_bytes(val).unwrap_or_default(); + debug_info!("Deleting key {key:?} which was uploaded by user {user}"); - Ok(()) + self.mediaid_user.remove(key); + }) + .await; } /// Searches for all files with the given MXC - pub(super) fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result<Vec<Vec<u8>>> { + pub(super) async fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result<Vec<Vec<u8>>> { debug!("MXC URI: {mxc}"); - let mut prefix: Vec<u8> = Vec::new(); - prefix.extend_from_slice(b"mxc://"); - prefix.extend_from_slice(mxc.server_name.as_bytes()); - prefix.extend_from_slice(b"/"); - prefix.extend_from_slice(mxc.media_id.as_bytes()); - prefix.push(0xFF); - + let prefix = (mxc, Interfix); let keys: Vec<Vec<u8>> = self .mediaid_file - .scan_prefix(prefix) - .map(|(key, _)| key) - .collect(); + .keys_prefix_raw(&prefix) + .ignore_err() + .map(<[u8]>::to_vec) + .collect() + .await; if keys.is_empty() { return Err!(Database("Failed to find any keys in database for `{mxc}`",)); @@ -128,22 +95,18 @@ pub(super) fn search_mxc_metadata_prefix(&self, mxc: &Mxc<'_>) -> Result<Vec<Vec Ok(keys) } - pub(super) fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result<Metadata> { - let mut prefix: Vec<u8> = Vec::new(); - prefix.extend_from_slice(b"mxc://"); - prefix.extend_from_slice(mxc.server_name.as_bytes()); - prefix.extend_from_slice(b"/"); - prefix.extend_from_slice(mxc.media_id.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(&dim.width.to_be_bytes()); - prefix.extend_from_slice(&dim.height.to_be_bytes()); - prefix.push(0xFF); - - let (key, _) = self + pub(super) async fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result<Metadata> { + let dim: &[u32] = &[dim.width, dim.height]; + let prefix = (mxc, dim, Interfix); + + let key = self .mediaid_file - .scan_prefix(prefix) + .keys_prefix_raw(&prefix) + .ignore_err() + .map(ToOwned::to_owned) .next() - .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; + .await + .ok_or_else(|| err!(Request(NotFound("Media not found"))))?; let mut parts = key.rsplit(|&b| b == 0xFF); @@ -177,32 +140,33 @@ pub(super) fn search_file_metadata(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result<Me } /// Gets all the MXCs associated with a user - 
pub(super) fn get_all_user_mxcs(&self, user_id: &UserId) -> Vec<OwnedMxcUri> {
- let user_id = user_id.as_bytes().to_vec();
-
+ pub(super) async fn get_all_user_mxcs(&self, user_id: &UserId) -> Vec<OwnedMxcUri> {
 self.mediaid_user
- .iter()
- .filter_map(|(key, user)| {
- if *user == user_id {
- let mxc_s = string_from_bytes(&key).ok()?;
- Some(OwnedMxcUri::from(mxc_s))
- } else {
- None
- }
- })
+ .stream()
+ .ignore_err()
+ .ready_filter_map(|(key, user): (&str, &UserId)| (user == user_id).then(|| key.into()))
 .collect()
+ .await
 }

 /// Gets all the media keys in our database (this includes all the metadata
 /// associated with it such as width, height, content-type, etc)
- pub(crate) fn get_all_media_keys(&self) -> Vec<Vec<u8>> { self.mediaid_file.iter().map(|(key, _)| key).collect() }
+ pub(crate) async fn get_all_media_keys(&self) -> Vec<Vec<u8>> {
+ self.mediaid_file
+ .raw_keys()
+ .ignore_err()
+ .map(<[u8]>::to_vec)
+ .collect()
+ .await
+ }

 #[inline]
- pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
+ pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> {
+ self.url_previews.remove(url.as_bytes());
+ Ok(())
+ }

- pub(super) fn set_url_preview(
- &self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration,
- ) -> Result<()> {
+ pub(super) fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: Duration) -> Result<()> {
 let mut value = Vec::<u8>::new();
 value.extend_from_slice(&timestamp.as_secs().to_be_bytes());
 value.push(0xFF);
@@ -233,11 +197,13 @@ pub(super) fn set_url_preview(
 value.push(0xFF);
 value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());

- self.url_previews.insert(url.as_bytes(), &value)
+ self.url_previews.insert(url.as_bytes(), &value);
+
+ Ok(())
 }

- pub(super) fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
- let values = self.url_previews.get(url.as_bytes()).ok()??;
+ pub(super) async fn get_url_preview(&self, url: &str) -> Result<UrlPreviewData> {
+ let values = self.url_previews.get(url).await?;

 let mut values = values.split(|&b| b == 0xFF);

@@ -291,7 +257,7 @@ pub(super) fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
 x => x,
 };

- Some(UrlPreviewData {
+ Ok(UrlPreviewData {
 title,
 description,
 image,
diff --git a/src/service/media/migrations.rs b/src/service/media/migrations.rs
index 9968d25b78be2d9bb863478486b84d1767a2353a..f1c6da7d8eb0ac130f7002383c65317b732bb48e 100644
--- a/src/service/media/migrations.rs
+++ b/src/service/media/migrations.rs
@@ -7,9 +7,13 @@
 time::Instant,
 };

-use conduit::{debug, debug_info, debug_warn, error, info, warn, Config, Result};
+use conduit::{
+ debug, debug_info, debug_warn, error, info,
+ utils::{stream::TryIgnore, ReadyExt},
+ warn, Config, Result,
+};

-use crate::{globals, Services};
+use crate::{migrations, Services};

 /// Migrates a media directory from legacy base64 file names to sha2 file names.
 /// All errors are fatal. 
Upon success the database is keyed to not perform this @@ -23,12 +27,17 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> { // Move old media files to new names let mut changes = Vec::<(PathBuf, PathBuf)>::new(); - for (key, _) in mediaid_file.iter() { - let old = services.media.get_media_file_b64(&key); - let new = services.media.get_media_file_sha256(&key); - debug!(?key, ?old, ?new, num = changes.len(), "change"); - changes.push((old, new)); - } + mediaid_file + .raw_keys() + .ignore_err() + .ready_for_each(|key| { + let old = services.media.get_media_file_b64(key); + let new = services.media.get_media_file_sha256(key); + debug!(?key, ?old, ?new, num = changes.len(), "change"); + changes.push((old, new)); + }) + .await; + // move the file to the new location for (old_path, path) in changes { if old_path.exists() { @@ -41,11 +50,11 @@ pub(crate) async fn migrate_sha256_media(services: &Services) -> Result<()> { // Apply fix from when sha256_media was backward-incompat and bumped the schema // version from 13 to 14. For users satisfying these conditions we can go back. - if services.globals.db.database_version()? == 14 && globals::migrations::DATABASE_VERSION == 13 { + if services.globals.db.database_version().await == 14 && migrations::DATABASE_VERSION == 13 { services.globals.db.bump_database_version(13)?; } - db["global"].insert(b"feat_sha256_media", &[])?; + db["global"].insert(b"feat_sha256_media", []); info!("Finished applying sha256_media"); Ok(()) } @@ -71,7 +80,7 @@ pub(crate) async fn checkup_sha256_media(services: &Services) -> Result<()> { .filter_map(|ent| ent.map_or(None, |ent| Some(ent.path().into_os_string()))) .collect(); - for key in media.db.get_all_media_keys() { + for key in media.db.get_all_media_keys().await { let new_path = media.get_media_file_sha256(&key).into_os_string(); let old_path = media.get_media_file_b64(&key).into_os_string(); if let Err(e) = handle_media_check(&dbs, config, &files, &key, &new_path, &old_path).await { @@ -112,8 +121,8 @@ async fn handle_media_check( "Media is missing at all paths. Removing from database..." 
); - mediaid_file.remove(key)?; - mediaid_user.remove(key)?; + mediaid_file.remove(key); + mediaid_user.remove(key); } if config.media_compat_file_link && !old_exists && new_exists { diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index d3765a176948c30cc211822e2f8fd7e59cfb0007..c0b15726f22a3def6fff21c2a5bd739472b056ab 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -97,7 +97,7 @@ pub async fn create( /// Deletes a file in the database and from the media directory via an MXC pub async fn delete(&self, mxc: &Mxc<'_>) -> Result<()> { - if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc) { + if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc).await { for key in keys { trace!(?mxc, "MXC Key: {key:?}"); debug_info!(?mxc, "Deleting from filesystem"); @@ -107,7 +107,7 @@ pub async fn delete(&self, mxc: &Mxc<'_>) -> Result<()> { } debug_info!(?mxc, "Deleting from database"); - _ = self.db.delete_file_mxc(mxc); + self.db.delete_file_mxc(mxc).await; } Ok(()) @@ -120,7 +120,7 @@ pub async fn delete(&self, mxc: &Mxc<'_>) -> Result<()> { /// /// currently, this is only practical for local users pub async fn delete_from_user(&self, user: &UserId) -> Result<usize> { - let mxcs = self.db.get_all_user_mxcs(user); + let mxcs = self.db.get_all_user_mxcs(user).await; let mut deletion_count: usize = 0; for mxc in mxcs { @@ -150,7 +150,7 @@ pub async fn get(&self, mxc: &Mxc<'_>) -> Result<Option<FileMeta>> { content_disposition, content_type, key, - }) = self.db.search_file_metadata(mxc, &Dim::default()) + }) = self.db.search_file_metadata(mxc, &Dim::default()).await { let mut content = Vec::new(); let path = self.get_media_file(&key); @@ -170,7 +170,7 @@ pub async fn get(&self, mxc: &Mxc<'_>) -> Result<Option<FileMeta>> { /// Gets all the MXC URIs in our media database pub async fn get_all_mxcs(&self) -> Result<Vec<OwnedMxcUri>> { - let all_keys = self.db.get_all_media_keys(); + let all_keys = self.db.get_all_media_keys().await; let mut mxcs = Vec::with_capacity(all_keys.len()); @@ -209,7 +209,7 @@ pub async fn get_all_mxcs(&self) -> Result<Vec<OwnedMxcUri>> { pub async fn delete_all_remote_media_at_after_time( &self, time: SystemTime, before: bool, after: bool, yes_i_want_to_delete_local_media: bool, ) -> Result<usize> { - let all_keys = self.db.get_all_media_keys(); + let all_keys = self.db.get_all_media_keys().await; let mut remote_mxcs = Vec::with_capacity(all_keys.len()); for key in all_keys { @@ -343,9 +343,10 @@ async fn create_media_file(&self, key: &[u8]) -> Result<fs::File> { } #[inline] - pub fn get_metadata(&self, mxc: &Mxc<'_>) -> Option<FileMeta> { + pub async fn get_metadata(&self, mxc: &Mxc<'_>) -> Option<FileMeta> { self.db .search_file_metadata(mxc, &Dim::default()) + .await .map(|metadata| FileMeta { content_disposition: metadata.content_disposition, content_type: metadata.content_type, diff --git a/src/service/media/preview.rs b/src/service/media/preview.rs index 5704075e5b8afd3ec430c4d08dbdabf4c2ed828d..acc9d8ed1fca4f325a98b5b3b96b193838184913 100644 --- a/src/service/media/preview.rs +++ b/src/service/media/preview.rs @@ -1,6 +1,6 @@ use std::{io::Cursor, time::SystemTime}; -use conduit::{debug, utils, warn, Err, Result}; +use conduit::{debug, utils, Err, Result}; use conduit_core::implement; use image::ImageReader as ImgReader; use ipaddress::IPAddress; @@ -70,30 +70,30 @@ pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> { } #[implement(Service)] -pub async fn get_url_preview(&self, url: &str) -> 
Result<UrlPreviewData> { - if let Some(preview) = self.db.get_url_preview(url) { +pub async fn get_url_preview(&self, url: &Url) -> Result<UrlPreviewData> { + if let Ok(preview) = self.db.get_url_preview(url.as_str()).await { return Ok(preview); } // ensure that only one request is made per URL - let _request_lock = self.url_preview_mutex.lock(url).await; + let _request_lock = self.url_preview_mutex.lock(url.as_str()).await; - match self.db.get_url_preview(url) { - Some(preview) => Ok(preview), - None => self.request_url_preview(url).await, + match self.db.get_url_preview(url.as_str()).await { + Ok(preview) => Ok(preview), + Err(_) => self.request_url_preview(url).await, } } #[implement(Service)] -async fn request_url_preview(&self, url: &str) -> Result<UrlPreviewData> { - if let Ok(ip) = IPAddress::parse(url) { +async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> { + if let Ok(ip) = IPAddress::parse(url.host_str().expect("URL previously validated")) { if !self.services.globals.valid_cidr_range(&ip) { return Err!(BadServerResponse("Requesting from this address is forbidden")); } } let client = &self.services.client.url_preview; - let response = client.head(url).send().await?; + let response = client.head(url.as_str()).send().await?; if let Some(remote_addr) = response.remote_addr() { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { @@ -111,12 +111,12 @@ async fn request_url_preview(&self, url: &str) -> Result<UrlPreviewData> { return Err!(Request(Unknown("Unknown Content-Type"))); }; let data = match content_type { - html if html.starts_with("text/html") => self.download_html(url).await?, - img if img.starts_with("image/") => self.download_image(url).await?, + html if html.starts_with("text/html") => self.download_html(url.as_str()).await?, + img if img.starts_with("image/") => self.download_image(url.as_str()).await?, _ => return Err!(Request(Unknown("Unsupported Content-Type"))), }; - self.set_url_preview(url, &data).await?; + self.set_url_preview(url.as_str(), &data).await?; Ok(data) } @@ -159,15 +159,7 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> { } #[implement(Service)] -pub fn url_preview_allowed(&self, url_str: &str) -> bool { - let url: Url = match Url::parse(url_str) { - Ok(u) => u, - Err(e) => { - warn!("Failed to parse URL from a str: {}", e); - return false; - }, - }; - +pub fn url_preview_allowed(&self, url: &Url) -> bool { if ["http", "https"] .iter() .all(|&scheme| scheme != url.scheme().to_lowercase()) diff --git a/src/service/media/remote.rs b/src/service/media/remote.rs index 59846b8ee1f6d7a123d2c0f6e86fb09ffd47689e..1c6c9ca0249229d3ccbd0ab8e2e9a5ecef0a97b8 100644 --- a/src/service/media/remote.rs +++ b/src/service/media/remote.rs @@ -382,8 +382,7 @@ fn check_fetch_authorized(&self, mxc: &Mxc<'_>) -> Result<()> { .server .config .prevent_media_downloads_from - .iter() - .any(|entry| entry == mxc.server_name) + .contains(mxc.server_name) { // we'll lie to the client and say the blocked server's media was not found and // log. the client has no way of telling anyways so this is a security bonus. 
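
The reworked `get_url_preview` above is a lock-then-recheck cache: it reads the cache, takes a per-URL mutex, and checks the cache again before fetching, so concurrent requests for the same URL collapse into a single remote fetch. Below is a minimal self-contained sketch of that pattern, assuming the `tokio` crate; `PreviewCache` and its `String` payload are hypothetical stand-ins for the real database map and `UrlPreviewData`:

```rust
use std::{collections::HashMap, sync::Arc};

use tokio::sync::Mutex;

/// Hypothetical stand-in for the url_previews column plus its lock table.
#[derive(Default)]
struct PreviewCache {
	/// url -> cached preview payload.
	previews: Mutex<HashMap<String, String>>,
	/// One mutex per URL; concurrent callers for the same URL serialize here.
	/// (A real implementation would also prune entries once idle.)
	locks: Mutex<HashMap<String, Arc<Mutex<()>>>>,
}

impl PreviewCache {
	async fn get_or_fetch(&self, url: &str) -> String {
		// Fast path: already cached, no per-URL lock needed.
		if let Some(hit) = self.previews.lock().await.get(url) {
			return hit.clone();
		}

		// Take (or create) this URL's mutex, then re-check the cache:
		// another task may have finished the fetch while we waited.
		let lock = self
			.locks
			.lock()
			.await
			.entry(url.to_owned())
			.or_default()
			.clone();
		let _guard = lock.lock().await;

		if let Some(hit) = self.previews.lock().await.get(url) {
			return hit.clone();
		}

		// Stand-in for the actual network fetch and HTML/image parsing.
		let fetched = format!("preview of {url}");
		self.previews
			.lock()
			.await
			.insert(url.to_owned(), fetched.clone());
		fetched
	}
}

#[tokio::main]
async fn main() {
	let cache = PreviewCache::default();
	let first = cache.get_or_fetch("https://example.com").await;
	let second = cache.get_or_fetch("https://example.com").await; // cache hit
	assert_eq!(first, second);
}
```

The second cache check after acquiring the per-URL lock is the important step: every waiter except the first finds the entry already populated and never issues its own request.
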
diff --git a/src/service/media/thumbnail.rs b/src/service/media/thumbnail.rs index 630f7b3b1b9bd017e5a0381a9d66bf13edf821df..04ec03039a39a286f6f93ee1bc34bfaba8e4c98d 100644 --- a/src/service/media/thumbnail.rs +++ b/src/service/media/thumbnail.rs @@ -54,9 +54,9 @@ pub async fn get_thumbnail(&self, mxc: &Mxc<'_>, dim: &Dim) -> Result<Option<Fil // 0, 0 because that's the original file let dim = dim.normalized(); - if let Ok(metadata) = self.db.search_file_metadata(mxc, &dim) { + if let Ok(metadata) = self.db.search_file_metadata(mxc, &dim).await { self.get_thumbnail_saved(metadata).await - } else if let Ok(metadata) = self.db.search_file_metadata(mxc, &Dim::default()) { + } else if let Ok(metadata) = self.db.search_file_metadata(mxc, &Dim::default()).await { self.get_thumbnail_generate(mxc, &dim, metadata).await } else { Ok(None) diff --git a/src/service/migrations.rs b/src/service/migrations.rs new file mode 100644 index 0000000000000000000000000000000000000000..126d3c7ef309b9f24e5c3eb957c8773e1651a670 --- /dev/null +++ b/src/service/migrations.rs @@ -0,0 +1,495 @@ +use std::cmp; + +use conduit::{ + debug, debug_info, debug_warn, error, info, + result::NotFound, + utils::{ + stream::{TryExpect, TryIgnore}, + IterStream, ReadyExt, + }, + warn, Err, Result, +}; +use futures::{FutureExt, StreamExt}; +use itertools::Itertools; +use ruma::{ + events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType}, + push::Ruleset, + OwnedUserId, UserId, +}; + +use crate::{media, Services}; + +/// The current schema version. +/// - If database is opened at greater version we reject with error. The +/// software must be updated for backward-incompatible changes. +/// - If database is opened at lesser version we apply migrations up to this. +/// Note that named-feature migrations may also be performed when opening at +/// equal or lesser version. These are expected to be backward-compatible. +pub(crate) const DATABASE_VERSION: u64 = 13; + +/// Conduit's database version. +/// +/// Conduit bumped the database version to 16, but did not introduce any +/// breaking changes. Their database migrations are extremely fragile and risky, +/// and also do not really apply to us, so just to retain Conduit -> conduwuit +/// compatibility we'll check for both versions. +pub(crate) const CONDUIT_DATABASE_VERSION: u64 = 16; + +pub(crate) async fn migrations(services: &Services) -> Result<()> { + let users_count = services.users.count().await; + + // Matrix resource ownership is based on the server name; changing it + // requires recreating the database from scratch. 
+ if users_count > 0 { + let conduit_user = &services.globals.server_user; + if !services.users.exists(conduit_user).await { + error!("The {conduit_user} server user does not exist, and the database is not new."); + return Err!(Database( + "Cannot reuse an existing database after changing the server name, please delete the old one first.", + )); + } + } + + if users_count > 0 { + migrate(services).await + } else { + fresh(services).await + } +} + +async fn fresh(services: &Services) -> Result<()> { + let db = &services.db; + + services + .globals + .db + .bump_database_version(DATABASE_VERSION)?; + + db["global"].insert(b"feat_sha256_media", []); + db["global"].insert(b"fix_bad_double_separator_in_state_cache", []); + db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []); + db["global"].insert(b"fix_referencedevents_missing_sep", []); + + // Create the admin room and server user on first run + crate::admin::create_admin_room(services).boxed().await?; + + warn!("Created new RocksDB database with version {DATABASE_VERSION}"); + + Ok(()) +} + +/// Apply any migrations +async fn migrate(services: &Services) -> Result<()> { + let db = &services.db; + let config = &services.server.config; + + if services.globals.db.database_version().await < 11 { + return Err!(Database( + "Database schema version {} is no longer supported", + services.globals.db.database_version().await + )); + } + + if services.globals.db.database_version().await < 12 { + db_lt_12(services).await?; + } + + // This migration can be reused as-is anytime the server-default rules are + // updated. + if services.globals.db.database_version().await < 13 { + db_lt_13(services).await?; + } + + if db["global"].get(b"feat_sha256_media").await.is_not_found() { + media::migrations::migrate_sha256_media(services).await?; + } else if config.media_startup_check { + media::migrations::checkup_sha256_media(services).await?; + } + + if db["global"] + .get(b"fix_bad_double_separator_in_state_cache") + .await + .is_not_found() + { + fix_bad_double_separator_in_state_cache(services).await?; + } + + if db["global"] + .get(b"retroactively_fix_bad_data_from_roomuserid_joined") + .await + .is_not_found() + { + retroactively_fix_bad_data_from_roomuserid_joined(services).await?; + } + + if db["global"] + .get(b"fix_referencedevents_missing_sep") + .await + .is_not_found() + { + fix_referencedevents_missing_sep(services).await?; + } + + let version_match = services.globals.db.database_version().await == DATABASE_VERSION + || services.globals.db.database_version().await == CONDUIT_DATABASE_VERSION; + + assert!( + version_match, + "Failed asserting local database version {} is equal to known latest conduwuit database version {}", + services.globals.db.database_version().await, + DATABASE_VERSION, + ); + + { + let patterns = services.globals.forbidden_usernames(); + if !patterns.is_empty() { + services + .users + .stream() + .filter(|user_id| services.users.is_active_local(user_id)) + .ready_for_each(|user_id| { + let matches = patterns.matches(user_id.localpart()); + if matches.matched_any() { + warn!( + "User {} matches the following forbidden username patterns: {}", + user_id.to_string(), + matches + .into_iter() + .map(|x| &patterns.patterns()[x]) + .join(", ") + ); + } + }) + .await; + } + } + + { + let patterns = services.globals.forbidden_alias_names(); + if !patterns.is_empty() { + for room_id in services + .rooms + .metadata + .iter_ids() + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + .await + { + services + .rooms + 
.alias + .local_aliases_for_room(&room_id) + .ready_for_each(|room_alias| { + let matches = patterns.matches(room_alias.alias()); + if matches.matched_any() { + warn!( + "Room with alias {} ({}) matches the following forbidden room name patterns: {}", + room_alias, + &room_id, + matches + .into_iter() + .map(|x| &patterns.patterns()[x]) + .join(", ") + ); + } + }) + .await; + } + } + } + + info!("Loaded RocksDB database with schema version {DATABASE_VERSION}"); + + Ok(()) +} + +async fn db_lt_12(services: &Services) -> Result<()> { + let config = &services.server.config; + + for username in &services + .users + .list_local_users() + .map(UserId::to_owned) + .collect::<Vec<_>>() + .await + { + let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) { + Ok(u) => u, + Err(e) => { + warn!("Invalid username {username}: {e}"); + continue; + }, + }; + + let mut account_data: PushRulesEvent = services + .account_data + .get_global(&user, GlobalAccountDataEventType::PushRules) + .await + .expect("Username is invalid"); + + let rules_list = &mut account_data.content.global; + + //content rule + { + let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"]; + + let rule = rules_list.content.get(content_rule_transformation[0]); + if rule.is_some() { + let mut rule = rule.unwrap().clone(); + content_rule_transformation[1].clone_into(&mut rule.rule_id); + rules_list + .content + .shift_remove(content_rule_transformation[0]); + rules_list.content.insert(rule); + } + } + + //underride rules + { + let underride_rule_transformation = [ + [".m.rules.call", ".m.rule.call"], + [".m.rules.room_one_to_one", ".m.rule.room_one_to_one"], + [".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one"], + [".m.rules.message", ".m.rule.message"], + [".m.rules.encrypted", ".m.rule.encrypted"], + ]; + + for transformation in underride_rule_transformation { + let rule = rules_list.underride.get(transformation[0]); + if let Some(rule) = rule { + let mut rule = rule.clone(); + transformation[1].clone_into(&mut rule.rule_id); + rules_list.underride.shift_remove(transformation[0]); + rules_list.underride.insert(rule); + } + } + } + + services + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; + } + + services.globals.db.bump_database_version(12)?; + info!("Migration: 11 -> 12 finished"); + Ok(()) +} + +async fn db_lt_13(services: &Services) -> Result<()> { + let config = &services.server.config; + + for username in &services + .users + .list_local_users() + .map(UserId::to_owned) + .collect::<Vec<_>>() + .await + { + let user = match UserId::parse_with_server_name(username.as_str(), &config.server_name) { + Ok(u) => u, + Err(e) => { + warn!("Invalid username {username}: {e}"); + continue; + }, + }; + + let mut account_data: PushRulesEvent = services + .account_data + .get_global(&user, GlobalAccountDataEventType::PushRules) + .await + .expect("Username is invalid"); + + let user_default_rules = Ruleset::server_default(&user); + account_data + .content + .global + .update_with_server_default(user_default_rules); + + services + .account_data + .update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + ) + .await?; + } + + services.globals.db.bump_database_version(13)?; + info!("Migration: 12 -> 13 
+	Ok(())
+}
+
+async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result<()> {
+	warn!("Fixing bad double separator in state_cache roomuserid_joined");
+
+	let db = &services.db;
+	let roomuserid_joined = &db["roomuserid_joined"];
+	let _cork = db.cork_and_sync();
+
+	let mut iter_count: usize = 0;
+	roomuserid_joined
+		.raw_stream()
+		.ignore_err()
+		.ready_for_each(|(key, value)| {
+			let mut key = key.to_vec();
+			iter_count = iter_count.saturating_add(1);
+			debug_info!(%iter_count);
+			let first_sep_index = key
+				.iter()
+				.position(|&i| i == 0xFF)
+				.expect("found 0xFF delim");
+
+			// a bad key has a doubled 0xFF separator; drop it and re-insert the
+			// entry under the corrected single-separator key
+			if key
+				.iter()
+				.get(first_sep_index..=first_sep_index.saturating_add(1))
+				.copied()
+				.collect_vec()
+				== vec![0xFF, 0xFF]
+			{
+				debug_warn!("Found bad key: {key:?}");
+				roomuserid_joined.remove(&key);
+
+				key.remove(first_sep_index);
+				debug_warn!("Fixed key: {key:?}");
+				roomuserid_joined.insert(&key, value);
+			}
+		})
+		.await;
+
+	db.db.cleanup()?;
+	db["global"].insert(b"fix_bad_double_separator_in_state_cache", []);
+
+	info!("Finished fixing");
+	Ok(())
+}
+
+async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) -> Result<()> {
+	warn!("Retroactively fixing bad data from broken roomuserid_joined");
+
+	let db = &services.db;
+	let _cork = db.cork_and_sync();
+
+	let room_ids = services
+		.rooms
+		.metadata
+		.iter_ids()
+		.map(ToOwned::to_owned)
+		.collect::<Vec<_>>()
+		.await;
+
+	for room_id in &room_ids {
+		debug_info!("Fixing room {room_id}");
+
+		let users_in_room: Vec<OwnedUserId> = services
+			.rooms
+			.state_cache
+			.room_members(room_id)
+			.map(ToOwned::to_owned)
+			.collect()
+			.await;
+
+		let joined_members = users_in_room
+			.iter()
+			.stream()
+			.filter(|user_id| {
+				services
+					.rooms
+					.state_accessor
+					.get_member(room_id, user_id)
+					.map(|member| member.map_or(false, |member| member.membership == MembershipState::Join))
+			})
+			.collect::<Vec<_>>()
+			.await;
+
+		// inverse of the above: members whose current state is leave or ban
+		let non_joined_members = users_in_room
+			.iter()
+			.stream()
+			.filter(|user_id| {
+				services
+					.rooms
+					.state_accessor
+					.get_member(room_id, user_id)
+					.map(|member| {
+						member.map_or(false, |member| {
+							member.membership == MembershipState::Leave || member.membership == MembershipState::Ban
+						})
+					})
+			})
+			.collect::<Vec<_>>()
+			.await;
+
+		for user_id in &joined_members {
+			debug_info!("User is joined, marking as joined");
+			services.rooms.state_cache.mark_as_joined(user_id, room_id);
+		}
+
+		for user_id in &non_joined_members {
+			debug_info!("User is left or banned, marking as left");
+			services.rooms.state_cache.mark_as_left(user_id, room_id);
+		}
+	}
+
+	for room_id in &room_ids {
+		debug_info!(
+			"Updating joined count for room {room_id} to fix servers in room after correcting membership states"
+		);
+
+		services
+			.rooms
+			.state_cache
+			.update_joined_count(room_id)
+			.await;
+	}
+
+	db.db.cleanup()?;
+	db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []);
+
+	info!("Finished fixing");
+	Ok(())
+}
+
+async fn fix_referencedevents_missing_sep(services: &Services) -> Result {
+	warn!("Fixing missing record separator between room_id and event_id in referencedevents");
+
+	let db = &services.db;
+	let cork = db.cork_and_sync();
+
+	let referencedevents = db["referencedevents"].clone();
+
+	let totals: (usize, usize) = (0, 0);
+	let (total, fixed) = referencedevents
+		.raw_stream()
+		.expect_ok()
+		.enumerate()
+		.ready_fold(totals, |mut a, (i, (key, val))| {
+			debug_assert!(val.is_empty(), "expected no value");
+
+			let has_sep = key.contains(&database::SEP);
+
+			if !has_sep {
+				// legacy keys are "<room_id><event_id>" with no record separator;
+				// the event_id begins at the first '$', so split there
+				let key_str =
std::str::from_utf8(key).expect("key not utf-8"); + let room_id_len = key_str.find('$').expect("missing '$' in key"); + let (room_id, event_id) = key_str.split_at(room_id_len); + debug!(?a, "fixing {room_id}, {event_id}"); + + let new_key = (room_id, event_id); + referencedevents.put_raw(new_key, val); + referencedevents.remove(key); + } + + a.0 = cmp::max(i, a.0); + a.1 = a.1.saturating_add((!has_sep).into()); + a + }) + .await; + + drop(cork); + info!(?total, ?fixed, "Fixed missing record separators in 'referencedevents'."); + + db["global"].insert(b"fix_referencedevents_missing_sep", []); + db.db.cleanup() +} diff --git a/src/service/mod.rs b/src/service/mod.rs index f588a542054c43f897ee37e3887ac5f72cc1e823..c7dcc0c611bff70519c790ca00140c2f467ab21d 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,7 +1,7 @@ -#![recursion_limit = "192"] #![allow(refining_impl_trait)] mod manager; +mod migrations; mod service; pub mod services; @@ -19,6 +19,7 @@ pub mod rooms; pub mod sending; pub mod server_keys; +pub mod sync; pub mod transaction_ids; pub mod uiaa; pub mod updates; diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index ec036b3d6e3999043e6852235a7290169606063f..68b2c3feb378ae47b1d877a3453a38e61e7e5096 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -1,8 +1,13 @@ use std::sync::Arc; -use conduit::{debug_warn, utils, Error, Result}; -use database::Map; -use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; +use conduit::{ + debug_warn, utils, + utils::{stream::TryIgnore, ReadyExt}, + Result, +}; +use database::{Deserialized, Json, Map}; +use futures::Stream; +use ruma::{events::presence::PresenceEvent, presence::PresenceState, UInt, UserId}; use super::Presence; use crate::{globals, users, Dep}; @@ -31,39 +36,35 @@ pub(super) fn new(args: &crate::Args<'_>) -> Self { } } - pub fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEvent)>> { - if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { - let count = utils::u64_from_bytes(&count_bytes) - .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; - - let key = presenceid_key(count, user_id); - self.presenceid_presence - .get(&key)? - .map(|presence_bytes| -> Result<(u64, PresenceEvent)> { - Ok(( - count, - Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id, &self.services.users)?, - )) - }) - .transpose() - } else { - Ok(None) - } + pub async fn get_presence(&self, user_id: &UserId) -> Result<(u64, PresenceEvent)> { + let count = self + .userid_presenceid + .get(user_id) + .await + .deserialized::<u64>()?; + + let key = presenceid_key(count, user_id); + let bytes = self.presenceid_presence.get(&key).await?; + let event = Presence::from_json_bytes(&bytes)? 
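+			// converting also fills in the user's current displayname and avatar_url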
+ .to_presence_event(user_id, &self.services.users) + .await; + + Ok((count, event)) } - pub(super) fn set_presence( + pub(super) async fn set_presence( &self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option<bool>, last_active_ago: Option<UInt>, status_msg: Option<String>, ) -> Result<()> { - let last_presence = self.get_presence(user_id)?; + let last_presence = self.get_presence(user_id).await; let state_changed = match last_presence { - None => true, - Some(ref presence) => presence.1.content.presence != *presence_state, + Err(_) => true, + Ok(ref presence) => presence.1.content.presence != *presence_state, }; let status_msg_changed = match last_presence { - None => true, - Some(ref last_presence) => { + Err(_) => true, + Ok(ref last_presence) => { let old_msg = last_presence .1 .content @@ -79,8 +80,8 @@ pub(super) fn set_presence( let now = utils::millis_since_unix_epoch(); let last_last_active_ts = match last_presence { - None => 0, - Some((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()), + Err(_) => 0, + Ok((_, ref presence)) => now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()), }; let last_active_ts = match last_active_ago { @@ -90,12 +91,7 @@ pub(super) fn set_presence( // TODO: tighten for state flicker? if !status_msg_changed && !state_changed && last_active_ts < last_last_active_ts { - debug_warn!( - "presence spam {:?} last_active_ts:{:?} < {:?}", - user_id, - last_active_ts, - last_last_active_ts - ); + debug_warn!("presence spam {user_id:?} last_active_ts:{last_active_ts:?} < {last_last_active_ts:?}",); return Ok(()); } @@ -111,58 +107,62 @@ pub(super) fn set_presence( last_active_ts, status_msg, ); + let count = self.services.globals.next_count()?; let key = presenceid_key(count, user_id); - self.presenceid_presence - .insert(&key, &presence.to_json_bytes()?)?; + self.presenceid_presence.raw_put(key, Json(presence)); + self.userid_presenceid.raw_put(user_id, count); - self.userid_presenceid - .insert(user_id.as_bytes(), &count.to_be_bytes())?; - - if let Some((last_count, _)) = last_presence { + if let Ok((last_count, _)) = last_presence { let key = presenceid_key(last_count, user_id); - self.presenceid_presence.remove(&key)?; + self.presenceid_presence.remove(&key); } Ok(()) } - pub(super) fn remove_presence(&self, user_id: &UserId) -> Result<()> { - if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? 
{ - let count = utils::u64_from_bytes(&count_bytes) - .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; - let key = presenceid_key(count, user_id); - self.presenceid_presence.remove(&key)?; - self.userid_presenceid.remove(user_id.as_bytes())?; - } + pub(super) async fn remove_presence(&self, user_id: &UserId) { + let Ok(count) = self + .userid_presenceid + .get(user_id) + .await + .deserialized::<u64>() + else { + return; + }; - Ok(()) + let key = presenceid_key(count, user_id); + self.presenceid_presence.remove(&key); + self.userid_presenceid.remove(user_id); } - pub fn presence_since<'a>(&'a self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + 'a> { - Box::new( - self.presenceid_presence - .iter() - .flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, Vec<u8>)> { - let (count, user_id) = presenceid_parse(&key)?; - Ok((user_id.to_owned(), count, presence_bytes)) - }) - .filter(move |(_, count, _)| *count > since), - ) + #[inline] + pub(super) fn presence_since(&self, since: u64) -> impl Stream<Item = (&UserId, u64, &[u8])> + Send + '_ { + self.presenceid_presence + .raw_stream() + .ignore_err() + .ready_filter_map(move |(key, presence)| { + let (count, user_id) = presenceid_parse(key).ok()?; + (count > since).then_some((user_id, count, presence)) + }) } } #[inline] fn presenceid_key(count: u64, user_id: &UserId) -> Vec<u8> { - [count.to_be_bytes().to_vec(), user_id.as_bytes().to_vec()].concat() + let cap = size_of::<u64>().saturating_add(user_id.as_bytes().len()); + let mut key = Vec::with_capacity(cap); + key.extend_from_slice(&count.to_be_bytes()); + key.extend_from_slice(user_id.as_bytes()); + key } #[inline] fn presenceid_parse(key: &[u8]) -> Result<(u64, &UserId)> { let (count, user_id) = key.split_at(8); let user_id = user_id_from_bytes(user_id)?; - let count = utils::u64_from_bytes(count).unwrap(); + let count = utils::u64_from_u8(count); Ok((count, user_id)) } diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index a54a6d7c5b19e6777c6a22af7860e3f702f46c56..b2106f3f7e037ff30ac26b7dc25a178888c92aeb 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -4,8 +4,8 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{checked, debug, error, Error, Result, Server}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use conduit::{checked, debug, error, result::LogErr, Error, Result, Server}; +use futures::{stream::FuturesUnordered, Stream, StreamExt, TryFutureExt}; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use tokio::{sync::Mutex, time::sleep}; @@ -55,12 +55,13 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { async fn worker(self: Arc<Self>) -> Result<()> { let mut presence_timers = FuturesUnordered::new(); let receiver = self.timer_receiver.lock().await; - loop { - debug_assert!(!receiver.is_closed(), "channel error"); + while !receiver.is_closed() { tokio::select! 
{ - Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, + Some(user_id) = presence_timers.next() => { + self.process_presence_timer(&user_id).await.log_err().ok(); + }, event = receiver.recv_async() => match event { - Err(_e) => return Ok(()), + Err(_) => break, Ok((user_id, timeout)) => { debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len()); presence_timers.push(presence_timer(user_id, timeout)); @@ -68,6 +69,8 @@ async fn worker(self: Arc<Self>) -> Result<()> { }, } } + + Ok(()) } fn interrupt(&self) { @@ -82,28 +85,27 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { /// Returns the latest presence event for the given user. #[inline] - pub fn get_presence(&self, user_id: &UserId) -> Result<Option<PresenceEvent>> { - if let Some((_, presence)) = self.db.get_presence(user_id)? { - Ok(Some(presence)) - } else { - Ok(None) - } + pub async fn get_presence(&self, user_id: &UserId) -> Result<PresenceEvent> { + self.db + .get_presence(user_id) + .map_ok(|(_, presence)| presence) + .await } /// Pings the presence of the given user in the given room, setting the /// specified state. - pub fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> { + pub async fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Result<()> { const REFRESH_TIMEOUT: u64 = 60 * 25 * 1000; - let last_presence = self.db.get_presence(user_id)?; + let last_presence = self.db.get_presence(user_id).await; let state_changed = match last_presence { - None => true, - Some((_, ref presence)) => presence.content.presence != *new_state, + Err(_) => true, + Ok((_, ref presence)) => presence.content.presence != *new_state, }; let last_last_active_ago = match last_presence { - None => 0_u64, - Some((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(), + Err(_) => 0_u64, + Ok((_, ref presence)) => presence.content.last_active_ago.unwrap_or_default().into(), }; if !state_changed && last_last_active_ago < REFRESH_TIMEOUT { @@ -111,17 +113,18 @@ pub fn ping_presence(&self, user_id: &UserId, new_state: &PresenceState) -> Resu } let status_msg = match last_presence { - Some((_, ref presence)) => presence.content.status_msg.clone(), - None => Some(String::new()), + Ok((_, ref presence)) => presence.content.status_msg.clone(), + Err(_) => Some(String::new()), }; let last_active_ago = UInt::new(0); let currently_active = *new_state == PresenceState::Online; self.set_presence(user_id, new_state, Some(currently_active), last_active_ago, status_msg) + .await } /// Adds a presence event which will be saved until a new event replaces it. - pub fn set_presence( + pub async fn set_presence( &self, user_id: &UserId, state: &PresenceState, currently_active: Option<bool>, last_active_ago: Option<UInt>, status_msg: Option<String>, ) -> Result<()> { @@ -131,7 +134,8 @@ pub fn set_presence( }; self.db - .set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg)?; + .set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg) + .await?; if self.timeout_remote_users || self.services.globals.user_is_local(user_id) { let timeout = match presence_state { @@ -154,28 +158,32 @@ pub fn set_presence( /// /// TODO: Why is this not used? 
#[allow(dead_code)] - pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) } + pub async fn remove_presence(&self, user_id: &UserId) { self.db.remove_presence(user_id).await } /// Returns the most recent presence updates that happened after the event /// with id `since`. - #[inline] - pub fn presence_since(&self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId, u64, Vec<u8>)> + '_> { + pub fn presence_since(&self, since: u64) -> impl Stream<Item = (&UserId, u64, &[u8])> + Send + '_ { self.db.presence_since(since) } - pub fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> { + #[inline] + pub async fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> { let presence = Presence::from_json_bytes(bytes)?; - presence.to_presence_event(user_id, &self.services.users) + let event = presence + .to_presence_event(user_id, &self.services.users) + .await; + + Ok(event) } - fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { + async fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { let mut presence_state = PresenceState::Offline; let mut last_active_ago = None; let mut status_msg = None; - let presence_event = self.get_presence(user_id)?; + let presence_event = self.get_presence(user_id).await; - if let Some(presence_event) = presence_event { + if let Ok(presence_event) = presence_event { presence_state = presence_event.content.presence; last_active_ago = presence_event.content.last_active_ago; status_msg = presence_event.content.status_msg; @@ -192,7 +200,8 @@ fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { ); if let Some(new_state) = new_state { - self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?; + self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg) + .await?; } Ok(()) diff --git a/src/service/presence/presence.rs b/src/service/presence/presence.rs index 570008f29f4be3cca669d9edcaa7bc5629ce4ab3..c4372003416456f9cc0f234f744cdd615474c868 100644 --- a/src/service/presence/presence.rs +++ b/src/service/presence/presence.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use conduit::{utils, Error, Result}; use ruma::{ events::presence::{PresenceEvent, PresenceEventContent}, @@ -37,12 +35,8 @@ pub(super) fn from_json_bytes(bytes: &[u8]) -> Result<Self> { serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database")) } - pub(super) fn to_json_bytes(&self) -> Result<Vec<u8>> { - serde_json::to_vec(self).map_err(|_| Error::bad_database("Could not serialize Presence to JSON")) - } - /// Creates a PresenceEvent from available data. 
- pub(super) fn to_presence_event(&self, user_id: &UserId, users: &Arc<users::Service>) -> Result<PresenceEvent> { + pub(super) async fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> PresenceEvent { let now = utils::millis_since_unix_epoch(); let last_active_ago = if self.currently_active { None @@ -50,16 +44,16 @@ pub(super) fn to_presence_event(&self, user_id: &UserId, users: &Arc<users::Serv Some(UInt::new_saturating(now.saturating_sub(self.last_active_ts))) }; - Ok(PresenceEvent { + PresenceEvent { sender: user_id.to_owned(), content: PresenceEventContent { presence: self.state.clone(), status_msg: self.status_msg.clone(), currently_active: Some(self.currently_active), last_active_ago, - displayname: users.displayname(user_id)?, - avatar_url: users.avatar_url(user_id)?, + displayname: users.displayname(user_id).await.ok(), + avatar_url: users.avatar_url(user_id).await.ok(), }, - }) + } } } diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs deleted file mode 100644 index f97343341c8d687668a825d88a89f08e0a26fef3..0000000000000000000000000000000000000000 --- a/src/service/pusher/data.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::{ - api::client::push::{set_pusher, Pusher}, - UserId, -}; - -pub(super) struct Data { - senderkey_pusher: Arc<Map>, -} - -impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { - Self { - senderkey_pusher: db["senderkey_pusher"].clone(), - } - } - - pub(super) fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> { - match pusher { - set_pusher::v3::PusherAction::Post(data) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); - self.senderkey_pusher - .insert(&key, &serde_json::to_vec(pusher).expect("Pusher is valid JSON value"))?; - Ok(()) - }, - set_pusher::v3::PusherAction::Delete(ids) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(ids.pushkey.as_bytes()); - self.senderkey_pusher.remove(&key).map_err(Into::into) - }, - } - } - - pub(super) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> { - let mut senderkey = sender.as_bytes().to_vec(); - senderkey.push(0xFF); - senderkey.extend_from_slice(pushkey.as_bytes()); - - self.senderkey_pusher - .get(&senderkey)? 
- .map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) - .transpose() - } - - pub(super) fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xFF); - - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) - .collect() - } - - pub(super) fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { - let mut parts = k.splitn(2, |&b| b == 0xFF); - let _senderkey = parts.next(); - let push_key = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; - let push_key_string = utils::string_from_bytes(push_key) - .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; - - Ok(push_key_string) - })) - } -} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index de87264c9ae421842b946f2858bbfe307627de0d..2b90319e9646a5523bb17047ee52e0e86f90ef6f 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,9 +1,13 @@ -mod data; - use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; -use conduit::{debug_error, err, trace, utils::string_from_bytes, warn, Err, PduEvent, Result}; +use conduit::{ + debug_error, err, trace, + utils::{stream::TryIgnore, string_from_bytes}, + Err, PduEvent, Result, +}; +use database::{Deserialized, Ignore, Interfix, Json, Map}; +use futures::{Stream, StreamExt}; use ipaddress::IPAddress; use ruma::{ api::{ @@ -22,12 +26,11 @@ uint, RoomId, UInt, UserId, }; -use self::data::Data; use crate::{client, globals, rooms, users, Dep}; pub struct Service { - services: Services, db: Data, + services: Services, } struct Services { @@ -38,9 +41,16 @@ struct Services { users: Dep<users::Service>, } +struct Data { + senderkey_pusher: Arc<Map>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + db: Data { + senderkey_pusher: args.db["senderkey_pusher"].clone(), + }, services: Services { globals: args.depend::<globals::Service>("globals"), client: args.depend::<client::Service>("client"), @@ -48,7 +58,6 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), users: args.depend::<users::Service>("users"), }, - db: Data::new(args.db), })) } @@ -56,19 +65,46 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> { - self.db.set_pusher(sender, pusher) + pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) { + match pusher { + set_pusher::v3::PusherAction::Post(data) => { + let key = (sender, &data.pusher.ids.pushkey); + self.db.senderkey_pusher.put(key, Json(pusher)); + }, + set_pusher::v3::PusherAction::Delete(ids) => { + let key = (sender, &ids.pushkey); + self.db.senderkey_pusher.del(key); + }, + } } - pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> { - self.db.get_pusher(sender, pushkey) + pub async fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Pusher> { + let senderkey = (sender, pushkey); + self.db + .senderkey_pusher + .qry(&senderkey) + .await + 
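+			// pushers are stored as JSON under a (sender, pushkey) composite key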
.deserialized() } - pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { self.db.get_pushers(sender) } + pub async fn get_pushers(&self, sender: &UserId) -> Vec<Pusher> { + let prefix = (sender, Interfix); + self.db + .senderkey_pusher + .stream_prefix(&prefix) + .ignore_err() + .map(|(_, pusher): (Ignore, Pusher)| pusher) + .collect() + .await + } - #[must_use] - pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + '_> { - self.db.get_pushkeys(sender) + pub fn get_pushkeys<'a>(&'a self, sender: &'a UserId) -> impl Stream<Item = &str> + Send + 'a { + let prefix = (sender, Interfix); + self.db + .senderkey_pusher + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, pushkey): (Ignore, &str)| pushkey) } #[tracing::instrument(skip(self, dest, request))] @@ -161,15 +197,18 @@ pub async fn send_push_notice( let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "") + .await + .and_then(|ev| { serde_json::from_str(ev.content.get()) - .map_err(|e| err!(Database("invalid m.room.power_levels event: {e:?}"))) + .map_err(|e| err!(Database(error!("invalid m.room.power_levels event: {e:?}")))) }) - .transpose()? .unwrap_or_default(); - for action in self.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)? { + for action in self + .get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id) + .await? + { let n = match action { Action::Notify => true, Action::SetTweak(tweak) => { @@ -197,7 +236,7 @@ pub async fn send_push_notice( } #[tracing::instrument(skip(self, user, ruleset, pdu), level = "debug")] - pub fn get_actions<'a>( + pub async fn get_actions<'a>( &self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent, pdu: &Raw<AnySyncTimelineEvent>, room_id: &RoomId, ) -> Result<&'a [Action]> { @@ -207,21 +246,27 @@ pub fn get_actions<'a>( notifications: power_levels.notifications.clone(), }; + let room_joined_count = self + .services + .state_cache + .room_joined_count(room_id) + .await + .unwrap_or(1) + .try_into() + .unwrap_or_else(|_| uint!(0)); + + let user_display_name = self + .services + .users + .displayname(user) + .await + .unwrap_or_else(|_| user.localpart().to_owned()); + let ctx = PushConditionRoomCtx { room_id: room_id.to_owned(), - member_count: UInt::try_from( - self.services - .state_cache - .room_joined_count(room_id)? - .unwrap_or(1), - ) - .unwrap_or_else(|_| uint!(0)), + member_count: room_joined_count, user_id: user.to_owned(), - user_display_name: self - .services - .users - .displayname(user)? 
- .unwrap_or_else(|| user.localpart().to_owned()), + user_display_name, power_levels: Some(power_levels), }; @@ -278,9 +323,21 @@ async fn send_notice(&self, unread: UInt, pusher: &Pusher, tweaks: Vec<Tweak>, e notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); } - notifi.sender_display_name = self.services.users.displayname(&event.sender)?; - - notifi.room_name = self.services.state_accessor.get_name(&event.room_id)?; + notifi.sender_display_name = self.services.users.displayname(&event.sender).await.ok(); + + notifi.room_name = self + .services + .state_accessor + .get_name(&event.room_id) + .await + .ok(); + + notifi.room_alias = self + .services + .state_accessor + .get_canonical_alias(&event.room_id) + .await + .ok(); self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) .await?; diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index 07d9a0fae4971e674e2bcf4c9a95775b556def59..5dc03d141b30a4bbd2bb1fe30f16dec0de179acd 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -9,43 +9,42 @@ use ipaddress::IPAddress; use ruma::ServerName; -use crate::resolver::{ - cache::{CachedDest, CachedOverride}, - fed::{add_port_to_hostname, get_ip_with_port, FedDest}, +use super::{ + cache::{CachedDest, CachedOverride, MAX_IPS}, + fed::{add_port_to_hostname, get_ip_with_port, FedDest, PortString}, }; #[derive(Clone, Debug)] pub(crate) struct ActualDest { pub(crate) dest: FedDest, pub(crate) host: String, - pub(crate) string: String, pub(crate) cached: bool, } +impl ActualDest { + #[inline] + pub(crate) fn string(&self) -> String { self.dest.https_string() } +} + impl super::Service { #[tracing::instrument(skip_all, name = "resolve")] pub(crate) async fn get_actual_dest(&self, server_name: &ServerName) -> Result<ActualDest> { - let cached; - let cached_result = self.get_cached_destination(server_name); + let (result, cached) = if let Some(result) = self.get_cached_destination(server_name) { + (result, true) + } else { + self.validate_dest(server_name)?; + (self.resolve_actual_dest(server_name, true).await?, false) + }; let CachedDest { dest, host, .. - } = if let Some(result) = cached_result { - cached = true; - result - } else { - cached = false; - self.validate_dest(server_name)?; - self.resolve_actual_dest(server_name, true).await? 
- }; + } = result; - let string = dest.clone().into_https_string(); Ok(ActualDest { dest, host, - string, cached, }) } @@ -78,18 +77,18 @@ pub async fn resolve_actual_dest(&self, dest: &ServerName, cache: bool) -> Resul let host = if let Ok(addr) = host.parse::<SocketAddr>() { FedDest::Literal(addr) } else if let Ok(addr) = host.parse::<IpAddr>() { - FedDest::Named(addr.to_string(), ":8448".to_owned()) + FedDest::Named(addr.to_string(), FedDest::default_port()) } else if let Some(pos) = host.find(':') { let (host, port) = host.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) + FedDest::Named(host.to_owned(), port.try_into().unwrap_or_else(|_| FedDest::default_port())) } else { - FedDest::Named(host, ":8448".to_owned()) + FedDest::Named(host, FedDest::default_port()) }; debug!("Actual destination: {actual_dest:?} hostname: {host:?}"); Ok(CachedDest { dest: actual_dest, - host: host.into_uri_string(), + host: host.uri_string(), expire: CachedDest::default_expire(), }) } @@ -104,12 +103,15 @@ async fn actual_dest_2(&self, dest: &ServerName, cache: bool, pos: usize) -> Res let (host, port) = dest.as_str().split_at(pos); self.conditional_query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448), cache) .await?; - Ok(FedDest::Named(host.to_owned(), port.to_owned())) + Ok(FedDest::Named( + host.to_owned(), + port.try_into().unwrap_or_else(|_| FedDest::default_port()), + )) } async fn actual_dest_3(&self, host: &mut String, cache: bool, delegated: String) -> Result<FedDest> { debug!("3: A .well-known file is available"); - *host = add_port_to_hostname(&delegated).into_uri_string(); + *host = add_port_to_hostname(&delegated).uri_string(); match get_ip_with_port(&delegated) { Some(host_and_port) => Self::actual_dest_3_1(host_and_port), None => { @@ -137,7 +139,10 @@ async fn actual_dest_3_2(&self, cache: bool, delegated: String, pos: usize) -> R let (host, port) = delegated.split_at(pos); self.conditional_query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448), cache) .await?; - Ok(FedDest::Named(host.to_owned(), port.to_owned())) + Ok(FedDest::Named( + host.to_owned(), + port.try_into().unwrap_or_else(|_| FedDest::default_port()), + )) } async fn actual_dest_3_3(&self, cache: bool, delegated: String, overrider: FedDest) -> Result<FedDest> { @@ -146,7 +151,13 @@ async fn actual_dest_3_3(&self, cache: bool, delegated: String, overrider: FedDe self.conditional_query_and_cache_override(&delegated, &overrider.hostname(), force_port.unwrap_or(8448), cache) .await?; if let Some(port) = force_port { - Ok(FedDest::Named(delegated, format!(":{port}"))) + Ok(FedDest::Named( + delegated, + format!(":{port}") + .as_str() + .try_into() + .unwrap_or_else(|_| FedDest::default_port()), + )) } else { Ok(add_port_to_hostname(&delegated)) } @@ -165,7 +176,11 @@ async fn actual_dest_4(&self, host: &str, cache: bool, overrider: FedDest) -> Re self.conditional_query_and_cache_override(host, &overrider.hostname(), force_port.unwrap_or(8448), cache) .await?; if let Some(port) = force_port { - Ok(FedDest::Named(host.to_owned(), format!(":{port}"))) + let port = format!(":{port}"); + Ok(FedDest::Named( + host.to_owned(), + PortString::from(port.as_str()).unwrap_or_else(|_| FedDest::default_port()), + )) } else { Ok(add_port_to_hostname(host)) } @@ -193,7 +208,7 @@ async fn request_well_known(&self, dest: &str) -> Result<Option<String>> { .send() .await; - trace!("response: {:?}", response); + trace!("response: {response:?}"); if let Err(e) = &response { debug!("error: {e:?}"); 
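+			// a failed fetch is not fatal; resolution proceeds without .well-known delegation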
return Ok(None); @@ -206,7 +221,7 @@ async fn request_well_known(&self, dest: &str) -> Result<Option<String>> { } let text = response.text().await?; - trace!("response text: {:?}", text); + trace!("response text: {text:?}"); if text.len() >= 12288 { debug_warn!("response contains junk"); return Ok(None); @@ -225,7 +240,7 @@ async fn request_well_known(&self, dest: &str) -> Result<Option<String>> { return Ok(None); } - debug_info!("{:?} found at {:?}", dest, m_server); + debug_info!("{dest:?} found at {m_server:?}"); Ok(Some(m_server.to_owned())) } @@ -251,9 +266,9 @@ async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, p } self.set_cached_override( - overname.to_owned(), + overname, CachedOverride { - ips: override_ip.iter().collect(), + ips: override_ip.into_iter().take(MAX_IPS).collect(), port, expire: CachedOverride::default_expire(), }, @@ -270,7 +285,10 @@ fn handle_successful_srv(srv: &SrvLookup) -> Option<FedDest> { srv.iter().next().map(|result| { FedDest::Named( result.target().to_string().trim_end_matches('.').to_owned(), - format!(":{}", result.port()), + format!(":{}", result.port()) + .as_str() + .try_into() + .unwrap_or_else(|_| FedDest::default_port()), ) }) } diff --git a/src/service/resolver/cache.rs b/src/service/resolver/cache.rs index 465b59855f9e84e7e8d852245ae9992d9dbc62f9..a13399dc86941089b5bad1ef37131daa040e1f16 100644 --- a/src/service/resolver/cache.rs +++ b/src/service/resolver/cache.rs @@ -5,6 +5,7 @@ time::SystemTime, }; +use arrayvec::ArrayVec; use conduit::{trace, utils::rand}; use ruma::{OwnedServerName, ServerName}; @@ -24,7 +25,7 @@ pub struct CachedDest { #[derive(Clone, Debug)] pub struct CachedOverride { - pub ips: Vec<IpAddr>, + pub ips: IpAddrs, pub port: u16, pub expire: SystemTime, } @@ -32,6 +33,9 @@ pub struct CachedOverride { pub type WellKnownMap = HashMap<OwnedServerName, CachedDest>; pub type TlsNameMap = HashMap<String, CachedOverride>; +pub type IpAddrs = ArrayVec<IpAddr, MAX_IPS>; +pub(crate) const MAX_IPS: usize = 3; + impl Cache { pub(super) fn new() -> Arc<Self> { Arc::new(Self { @@ -61,13 +65,13 @@ pub fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> { .cloned() } - pub fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> { + pub fn set_cached_override(&self, name: &str, over: CachedOverride) -> Option<CachedOverride> { trace!(?name, ?over, "set cached override"); self.cache .overrides .write() .expect("locked for writing") - .insert(name, over) + .insert(name.into(), over) } #[must_use] diff --git a/src/service/resolver/dns.rs b/src/service/resolver/dns.rs index b77bbb84fe9e83d14359e6298d450e7a79bd890d..d3e9f5c93fba2f90f6856b79e5aa1bc405ed03da 100644 --- a/src/service/resolver/dns.rs +++ b/src/service/resolver/dns.rs @@ -1,15 +1,11 @@ -use std::{ - future, iter, - net::{IpAddr, SocketAddr}, - sync::Arc, - time::Duration, -}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use conduit::{err, Result, Server}; +use futures::FutureExt; use hickory_resolver::TokioAsyncResolver; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; -use super::cache::Cache; +use super::cache::{Cache, CachedOverride}; pub struct Resolver { pub(crate) resolver: Arc<TokioAsyncResolver>, @@ -21,6 +17,8 @@ pub(crate) struct Hooked { cache: Arc<Cache>, } +type ResolvingResult = Result<Addrs, Box<dyn std::error::Error + Send + Sync>>; + impl Resolver { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] pub(super) fn build(server: 
&Arc<Server>, cache: Arc<Cache>) -> Result<Arc<Self>> { @@ -82,12 +80,12 @@ pub(super) fn build(server: &Arc<Server>, cache: Arc<Cache>) -> Result<Arc<Self> } impl Resolve for Resolver { - fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } + fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name).boxed() } } impl Resolve for Hooked { fn resolve(&self, name: Name) -> Resolving { - let cached = self + let cached: Option<CachedOverride> = self .cache .overrides .read() @@ -95,35 +93,28 @@ fn resolve(&self, name: Name) -> Resolving { .get(name.as_str()) .cloned(); - if let Some(cached) = cached { - cached_to_reqwest(&cached.ips, cached.port) - } else { - resolve_to_reqwest(self.resolver.clone(), name) - } + cached.map_or_else( + || resolve_to_reqwest(self.resolver.clone(), name).boxed(), + |cached| cached_to_reqwest(cached).boxed(), + ) } } -fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { - override_name - .first() - .map(|first_name| -> Resolving { - let saddr = SocketAddr::new(*first_name, port); - let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr)); - Box::pin(future::ready(Ok(result))) - }) - .expect("must provide at least one override name") -} +async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult { + let addrs = cached + .ips + .into_iter() + .map(move |ip| SocketAddr::new(ip, cached.port)); -fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving { - Box::pin(async move { - let results = resolver - .lookup_ip(name.as_str()) - .await? - .into_iter() - .map(|ip| SocketAddr::new(ip, 0)); + Ok(Box::new(addrs)) +} - let results: Addrs = Box::new(results); +async fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> ResolvingResult { + let results = resolver + .lookup_ip(name.as_str()) + .await? 
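+		// port 0 is a placeholder here; the connector applies the request's actual port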
+ .into_iter() + .map(|ip| SocketAddr::new(ip, 0)); - Ok(results) - }) + Ok(Box::new(results)) } diff --git a/src/service/resolver/fed.rs b/src/service/resolver/fed.rs index 10cbbbdd0b9cc94e1d727d4f9ebebe280876c683..9c348b47ef04aeb5da63662075dc1cf4369f58b4 100644 --- a/src/service/resolver/fed.rs +++ b/src/service/resolver/fed.rs @@ -1,14 +1,22 @@ use std::{ + borrow::Cow, fmt, net::{IpAddr, SocketAddr}, }; +use arrayvec::ArrayString; + #[derive(Clone, Debug, PartialEq, Eq)] pub enum FedDest { Literal(SocketAddr), - Named(String, String), + Named(String, PortString), } +/// numeric or service-name +pub type PortString = ArrayString<16>; + +const DEFAULT_PORT: &str = ":8448"; + pub(crate) fn get_ip_with_port(dest_str: &str) -> Option<FedDest> { if let Ok(dest) = dest_str.parse::<SocketAddr>() { Some(FedDest::Literal(dest)) @@ -19,34 +27,38 @@ pub(crate) fn get_ip_with_port(dest_str: &str) -> Option<FedDest> { } } -pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest { - let (host, port) = match dest_str.find(':') { - None => (dest_str, ":8448"), - Some(pos) => dest_str.split_at(pos), +pub(crate) fn add_port_to_hostname(dest: &str) -> FedDest { + let (host, port) = match dest.find(':') { + None => (dest, DEFAULT_PORT), + Some(pos) => dest.split_at(pos), }; - FedDest::Named(host.to_owned(), port.to_owned()) + FedDest::Named( + host.to_owned(), + PortString::from(port).unwrap_or_else(|_| FedDest::default_port()), + ) } impl FedDest { - pub(crate) fn into_https_string(self) -> String { + pub(crate) fn https_string(&self) -> String { match self { Self::Literal(addr) => format!("https://{addr}"), Self::Named(host, port) => format!("https://{host}{port}"), } } - pub(crate) fn into_uri_string(self) -> String { + pub(crate) fn uri_string(&self) -> String { match self { Self::Literal(addr) => addr.to_string(), Self::Named(host, port) => format!("{host}{port}"), } } - pub(crate) fn hostname(&self) -> String { + #[inline] + pub(crate) fn hostname(&self) -> Cow<'_, str> { match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), + Self::Literal(addr) => addr.ip().to_string().into(), + Self::Named(host, _) => host.into(), } } @@ -58,13 +70,12 @@ pub(crate) fn port(&self) -> Option<u16> { Self::Named(_, port) => port[1..].parse().ok(), } } + + #[inline] + #[must_use] + pub fn default_port() -> PortString { PortString::from(DEFAULT_PORT).expect("default port string") } } impl fmt::Display for FedDest { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Named(host, port) => write!(f, "{host}{port}"), - Self::Literal(addr) => write!(f, "{addr}"), - } - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.uri_string().as_str()) } } diff --git a/src/service/resolver/tests.rs b/src/service/resolver/tests.rs index 55cf0345d78c20b3cc64ea4697e59c457d563b9b..870f5eabfd9059b789a47eab6b40ae6d6621abdf 100644 --- a/src/service/resolver/tests.rs +++ b/src/service/resolver/tests.rs @@ -30,7 +30,7 @@ fn ips_keep_custom_ports() { fn hostnames_get_default_ports() { assert_eq!( add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) + FedDest::Named(String::from("example.com"), ":8448".try_into().unwrap()) ); } @@ -38,6 +38,6 @@ fn hostnames_get_default_ports() { fn hostnames_keep_custom_ports() { assert_eq!( add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) + FedDest::Named(String::from("example.com"), 
":1337".try_into().unwrap()) ); } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs deleted file mode 100644 index efd2b5b76562df2aa2382414a685f96ba6076235..0000000000000000000000000000000000000000 --- a/src/service/rooms/alias/data.rs +++ /dev/null @@ -1,125 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - alias_userid: Arc<Map>, - alias_roomid: Arc<Map>, - aliasid_alias: Arc<Map>, - services: Services, -} - -struct Services { - globals: Dep<globals::Service>, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - alias_userid: db["alias_userid"].clone(), - alias_roomid: db["alias_roomid"].clone(), - aliasid_alias: db["aliasid_alias"].clone(), - services: Services { - globals: args.depend::<globals::Service>("globals"), - }, - } - } - - pub(super) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { - // Comes first as we don't want a stuck alias - self.alias_userid - .insert(alias.alias().as_bytes(), user_id.as_bytes())?; - - self.alias_roomid - .insert(alias.alias().as_bytes(), room_id.as_bytes())?; - - let mut aliasid = room_id.as_bytes().to_vec(); - aliasid.push(0xFF); - aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; - - Ok(()) - } - - pub(super) fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { - if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { - let mut prefix = room_id.to_vec(); - prefix.push(0xFF); - - for (key, _) in self.aliasid_alias.scan_prefix(prefix) { - self.aliasid_alias.remove(&key)?; - } - - self.alias_roomid.remove(alias.alias().as_bytes())?; - - self.alias_userid.remove(alias.alias().as_bytes())?; - } else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist or is invalid.")); - } - - Ok(()) - } - - pub(super) fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { - self.alias_roomid - .get(alias.alias().as_bytes())? - .map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) - }) - .transpose() - } - - pub(super) fn who_created_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedUserId>> { - self.alias_userid - .get(alias.alias().as_bytes())? - .map(|bytes| { - UserId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("User ID in alias_userid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in alias_roomid is invalid.")) - }) - .transpose() - } - - pub(super) fn local_aliases_for_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a + Send> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? 
- .try_into() - .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) - })) - } - - pub(super) fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> { - Box::new( - self.alias_roomid - .iter() - .map(|(room_alias_bytes, room_id_bytes)| { - let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?; - - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; - - Ok((room_id, room_alias_localpart)) - }), - ) - } -} diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index f2e01ab548b00f03f9727cdd5cb57503edf5176d..0cdec8eeb7e0200941e2dc255486b3e1a2de57d1 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,19 +1,23 @@ -mod data; mod remote; use std::sync::Arc; -use conduit::{err, Error, Result}; +use conduit::{ + err, + utils::{stream::TryIgnore, ReadyExt}, + Err, Error, Result, +}; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{Stream, StreamExt}; use ruma::{ api::client::error::ErrorKind, events::{ room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, StateEventType, }, - OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, UserId, }; -use self::data::Data; use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, Dep}; pub struct Service { @@ -21,6 +25,12 @@ pub struct Service { services: Services, } +struct Data { + alias_userid: Arc<Map>, + alias_roomid: Arc<Map>, + aliasid_alias: Arc<Map>, +} + struct Services { admin: Dep<admin::Service>, appservice: Dep<appservice::Service>, @@ -32,7 +42,11 @@ struct Services { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + alias_userid: args.db["alias_userid"].clone(), + alias_roomid: args.db["alias_roomid"].clone(), + aliasid_alias: args.db["aliasid_alias"].clone(), + }, services: Services { admin: args.depend::<admin::Service>("admin"), appservice: args.depend::<appservice::Service>("appservice"), @@ -50,122 +64,172 @@ impl Service { #[tracing::instrument(skip(self))] pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user { - Err(Error::BadRequest( + return Err(Error::BadRequest( ErrorKind::forbidden(), "Only the server user can set this alias", - )) - } else { - self.db.set_alias(alias, room_id, user_id) + )); } + + // Comes first as we don't want a stuck alias + self.db + .alias_userid + .insert(alias.alias().as_bytes(), user_id.as_bytes()); + + self.db + .alias_roomid + .insert(alias.alias().as_bytes(), room_id.as_bytes()); + + let mut aliasid = room_id.as_bytes().to_vec(); + aliasid.push(0xFF); + aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + self.db.aliasid_alias.insert(&aliasid, alias.as_bytes()); + + Ok(()) } #[tracing::instrument(skip(self))] pub async fn remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result<()> { - if self.user_can_remove_alias(alias, user_id).await? 
{ - self.db.remove_alias(alias) - } else { - Err(Error::BadRequest( - ErrorKind::forbidden(), - "User is not permitted to remove this alias.", - )) + if !self.user_can_remove_alias(alias, user_id).await? { + return Err!(Request(Forbidden("User is not permitted to remove this alias."))); } + + let alias = alias.alias(); + let Ok(room_id) = self.db.alias_roomid.get(&alias).await else { + return Err!(Request(NotFound("Alias does not exist or is invalid."))); + }; + + let prefix = (&room_id, Interfix); + self.db + .aliasid_alias + .keys_prefix_raw(&prefix) + .ignore_err() + .ready_for_each(|key| self.db.aliasid_alias.remove(key)) + .await; + + self.db.alias_roomid.remove(alias.as_bytes()); + self.db.alias_userid.remove(alias.as_bytes()); + + Ok(()) } + #[inline] pub async fn resolve(&self, room: &RoomOrAliasId) -> Result<OwnedRoomId> { + self.resolve_with_servers(room, None) + .await + .map(|(room_id, _)| room_id) + } + + pub async fn resolve_with_servers( + &self, room: &RoomOrAliasId, servers: Option<Vec<OwnedServerName>>, + ) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> { if room.is_room_id() { - let room_id: &RoomId = &RoomId::parse(room).expect("valid RoomId"); - Ok(room_id.to_owned()) + let room_id = RoomId::parse(room).expect("valid RoomId"); + Ok((room_id, servers.unwrap_or_default())) } else { - let alias: &RoomAliasId = &RoomAliasId::parse(room).expect("valid RoomAliasId"); - Ok(self.resolve_alias(alias, None).await?.0) + let alias = &RoomAliasId::parse(room).expect("valid RoomAliasId"); + self.resolve_alias(alias, servers).await } } #[tracing::instrument(skip(self), name = "resolve")] pub async fn resolve_alias( - &self, room_alias: &RoomAliasId, servers: Option<&Vec<OwnedServerName>>, - ) -> Result<(OwnedRoomId, Option<Vec<OwnedServerName>>)> { - if !self - .services - .globals - .server_is_ours(room_alias.server_name()) - && (!servers + &self, room_alias: &RoomAliasId, servers: Option<Vec<OwnedServerName>>, + ) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> { + let server_name = room_alias.server_name(); + let server_is_ours = self.services.globals.server_is_ours(server_name); + let servers_contains_ours = || { + servers .as_ref() - .is_some_and(|servers| servers.contains(&self.services.globals.server_name().to_owned())) - || servers.as_ref().is_none()) - { - return self.remote_resolve(room_alias, servers).await; + .is_some_and(|servers| servers.contains(&self.services.globals.config.server_name)) + }; + + if !server_is_ours && !servers_contains_ours() { + return self + .remote_resolve(room_alias, servers.unwrap_or_default()) + .await; } - let room_id: Option<OwnedRoomId> = match self.resolve_local_alias(room_alias)? 
{ - Some(r) => Some(r), - None => self.resolve_appservice_alias(room_alias).await?, + let room_id = match self.resolve_local_alias(room_alias).await { + Ok(r) => Some(r), + Err(_) => self.resolve_appservice_alias(room_alias).await?, }; room_id.map_or_else( - || Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")), - |room_id| Ok((room_id, None)), + || Err!(Request(NotFound("Room with alias not found."))), + |room_id| Ok((room_id, Vec::new())), ) } #[tracing::instrument(skip(self), level = "debug")] - pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { - self.db.resolve_local_alias(alias) + pub async fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<OwnedRoomId> { + self.db.alias_roomid.get(alias.alias()).await.deserialized() } #[tracing::instrument(skip(self), level = "debug")] - pub fn local_aliases_for_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a + Send> { - self.db.local_aliases_for_room(room_id) + pub fn local_aliases_for_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &RoomAliasId> + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .aliasid_alias + .stream_prefix(&prefix) + .ignore_err() + .map(|(_, alias): (Ignore, &RoomAliasId)| alias) } #[tracing::instrument(skip(self), level = "debug")] - pub fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> { - self.db.all_local_aliases() + pub fn all_local_aliases<'a>(&'a self) -> impl Stream<Item = (&RoomId, &str)> + Send + 'a { + self.db + .alias_roomid + .stream() + .ignore_err() + .map(|(alias_localpart, room_id): (&str, &RoomId)| (room_id, alias_localpart)) } async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result<bool> { - let Some(room_id) = self.resolve_local_alias(alias)? else { - return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found.")); - }; + let room_id = self + .resolve_local_alias(alias) + .await + .map_err(|_| err!(Request(NotFound("Alias not found."))))?; let server_user = &self.services.globals.server_user; // The creator of an alias can remove it if self - .db - .who_created_alias(alias)? - .is_some_and(|user| user == user_id) + .who_created_alias(alias).await + .is_ok_and(|user| user == user_id) // Server admins can remove any local alias - || self.services.admin.user_is_admin(user_id).await? + || self.services.admin.user_is_admin(user_id).await // Always allow the server service account to remove the alias, since there may not be an admin room || server_user == user_id { - Ok(true) - // Checking whether the user is able to change canonical aliases of the - // room - } else if let Some(event) = - self.services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? 
+ return Ok(true); + } + + // Checking whether the user is able to change canonical aliases of the room + if let Ok(content) = self + .services + .state_accessor + .room_state_get_content::<RoomPowerLevelsEventContent>(&room_id, &StateEventType::RoomPowerLevels, "") + .await { - serde_json::from_str(event.content.get()) - .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) - .map(|content: RoomPowerLevelsEventContent| { - RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomCanonicalAlias) - }) + return Ok(RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomCanonicalAlias)); + } + // If there is no power levels event, only the room creator can change // canonical aliases - } else if let Some(event) = - self.services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? + if let Ok(event) = self + .services + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomCreate, "") + .await { - Ok(event.sender == user_id) - } else { - Err(Error::bad_database("Room has no m.room.create event")) + return Ok(event.sender == user_id); } + + Err!(Database("Room has no m.room.create event")) + } + + async fn who_created_alias(&self, alias: &RoomAliasId) -> Result<OwnedUserId> { + self.db.alias_userid.get(alias.alias()).await.deserialized() } async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { @@ -185,10 +249,11 @@ async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Opt .await, Ok(Some(_opt_result)) ) { - return Ok(Some( - self.resolve_local_alias(room_alias)? - .ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?, - )); + return self + .resolve_local_alias(room_alias) + .await + .map_err(|_| err!(Request(NotFound("Room does not exist.")))) + .map(Some); } } diff --git a/src/service/rooms/alias/remote.rs b/src/service/rooms/alias/remote.rs index 5d835240b26e9e8337738268e10c4cc75bf4d3e4..d9acccc9c6aa7e1882a12c8ec55a13db330ea2fb 100644 --- a/src/service/rooms/alias/remote.rs +++ b/src/service/rooms/alias/remote.rs @@ -1,75 +1,67 @@ -use conduit::{debug, debug_warn, Error, Result}; -use ruma::{ - api::{client::error::ErrorKind, federation}, - OwnedRoomId, OwnedServerName, RoomAliasId, -}; +use std::iter::once; -impl super::Service { - pub(super) async fn remote_resolve( - &self, room_alias: &RoomAliasId, servers: Option<&Vec<OwnedServerName>>, - ) -> Result<(OwnedRoomId, Option<Vec<OwnedServerName>>)> { - debug!(?room_alias, ?servers, "resolve"); +use conduit::{debug, debug_error, err, implement, Result}; +use federation::query::get_room_information::v1::Response; +use ruma::{api::federation, OwnedRoomId, OwnedServerName, RoomAliasId, ServerName}; - let mut response = self - .services - .sending - .send_federation_request( - room_alias.server_name(), - federation::query::get_room_information::v1::Request { - room_alias: room_alias.to_owned(), - }, - ) - .await; +#[implement(super::Service)] +pub(super) async fn remote_resolve( + &self, room_alias: &RoomAliasId, servers: Vec<OwnedServerName>, +) -> Result<(OwnedRoomId, Vec<OwnedServerName>)> { + debug!(?room_alias, servers = ?servers, "resolve"); + let servers = once(room_alias.server_name()) + .map(ToOwned::to_owned) + .chain(servers.into_iter()); - debug!("room alias server_name get_alias_helper response: {response:?}"); + let mut resolved_servers = Vec::new(); + let mut resolved_room_id: Option<OwnedRoomId> = None; + for server in servers { + match 
self.remote_request(room_alias, &server).await { + Err(e) => debug_error!("Failed to query for {room_alias:?} from {server}: {e}"), + Ok(Response { + room_id, + servers, + }) => { + debug!("Server {server} answered with {room_id:?} for {room_alias:?} servers: {servers:?}"); - if let Err(ref e) = response { - debug_warn!( - "Server {} of the original room alias failed to assist in resolving room alias: {e}", - room_alias.server_name(), - ); - } - - if response.as_ref().is_ok_and(|resp| resp.servers.is_empty()) || response.as_ref().is_err() { - if let Some(servers) = servers { - for server in servers { - response = self - .services - .sending - .send_federation_request( - server, - federation::query::get_room_information::v1::Request { - room_alias: room_alias.to_owned(), - }, - ) - .await; - debug!("Got response from server {server} for room aliases: {response:?}"); + resolved_room_id.get_or_insert(room_id); + add_server(&mut resolved_servers, server); - if let Ok(ref response) = response { - if !response.servers.is_empty() { - break; - } - debug_warn!( - "Server {server} responded with room aliases, but was empty? Response: {response:?}" - ); - } + if !servers.is_empty() { + add_servers(&mut resolved_servers, servers); + break; } - } + }, } + } - if let Ok(response) = response { - let room_id = response.room_id; + resolved_room_id + .map(|room_id| (room_id, resolved_servers)) + .ok_or_else(|| err!(Request(NotFound("No servers could assist in resolving the room alias")))) +} - let mut pre_servers = response.servers; - // since the room alis server responded, insert it into the list - pre_servers.push(room_alias.server_name().into()); +#[implement(super::Service)] +async fn remote_request(&self, room_alias: &RoomAliasId, server: &ServerName) -> Result<Response> { + use federation::query::get_room_information::v1::Request; - return Ok((room_id, Some(pre_servers))); - } + let request = Request { + room_alias: room_alias.to_owned(), + }; + + self.services + .sending + .send_federation_request(server, request) + .await +} + +fn add_servers(servers: &mut Vec<OwnedServerName>, new: Vec<OwnedServerName>) { + for server in new { + add_server(servers, server); + } +} - Err(Error::BadRequest( - ErrorKind::NotFound, - "No servers could assist in resolving the room alias", - )) +fn add_server(servers: &mut Vec<OwnedServerName>, server: OwnedServerName) { + if !servers.contains(&server) { + servers.push(server); } } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 6e7c783594903002f8f4473b92d876930bf3374e..3c36928afc9a4e34b089d2f5ea89c0d4988c16fb 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -3,13 +3,15 @@ sync::{Arc, Mutex}, }; -use conduit::{utils, utils::math::usize_from_f64, Result}; +use conduit::{err, utils, utils::math::usize_from_f64, Err, Result}; use database::Map; use lru_cache::LruCache; +use crate::rooms::short::ShortEventId; + pub(super) struct Data { shorteventid_authchain: Arc<Map>, - pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[u64]>>>, + pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[ShortEventId]>>>, } impl Data { @@ -24,57 +26,63 @@ pub(super) fn new(args: &crate::Args<'_>) -> Self { } } - pub(super) fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> { + pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[ShortEventId]>> { + debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); + // Check 
RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { - return Ok(Some(Arc::clone(result))); + if let Some(result) = self + .auth_chain_cache + .lock() + .expect("cache locked") + .get_mut(key) + { + return Ok(Arc::clone(result)); } // We only save auth chains for single events in the db - if key.len() == 1 { - // Check DB cache - let chain = self - .shorteventid_authchain - .get(&key[0].to_be_bytes())? - .map(|chain| { - chain - .chunks_exact(size_of::<u64>()) - .map(utils::u64_from_u8) - .collect::<Arc<[u64]>>() - }); + if key.len() != 1 { + return Err!(Request(NotFound("auth_chain not cached"))); + } - if let Some(chain) = chain { - // Cache in RAM - self.auth_chain_cache - .lock() - .expect("locked") - .insert(vec![key[0]], Arc::clone(&chain)); + // Check database + let chain = self + .shorteventid_authchain + .qry(&key[0]) + .await + .map_err(|_| err!(Request(NotFound("auth_chain not found"))))?; - return Ok(Some(chain)); - } - } + let chain = chain + .chunks_exact(size_of::<u64>()) + .map(utils::u64_from_u8) + .collect::<Arc<[u64]>>(); - Ok(None) + // Cache in RAM + self.auth_chain_cache + .lock() + .expect("cache locked") + .insert(vec![key[0]], Arc::clone(&chain)); + + Ok(chain) } - pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) -> Result<()> { + pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[ShortEventId]>) { + debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); + // Only persist single events in db if key.len() == 1 { - self.shorteventid_authchain.insert( - &key[0].to_be_bytes(), - &auth_chain - .iter() - .flat_map(|s| s.to_be_bytes().to_vec()) - .collect::<Vec<u8>>(), - )?; + let key = key[0].to_be_bytes(); + let val = auth_chain + .iter() + .flat_map(|s| s.to_be_bytes().to_vec()) + .collect::<Vec<u8>>(); + + self.shorteventid_authchain.insert(&key, &val); } // Cache in RAM self.auth_chain_cache .lock() - .expect("locked") + .expect("cache locked") .insert(key, auth_chain); - - Ok(()) } } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 9a1e7e67af3b2f9ca1148fc6c402a61ace2a669c..cabb6f0cad9b2a3080a0ae5750851d0b17c1ef88 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -5,11 +5,12 @@ sync::Arc, }; -use conduit::{debug, error, trace, validated, warn, Err, Result}; +use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result}; +use futures::{Stream, StreamExt}; use ruma::{EventId, RoomId}; use self::data::Data; -use crate::{rooms, Dep}; +use crate::{rooms, rooms::short::ShortEventId, Dep}; pub struct Service { services: Services, @@ -36,37 +37,49 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - pub async fn event_ids_iter<'a>( - &'a self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>, - ) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> { - let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len()); - for starting_event in &starting_events_ { - starting_events.push(starting_event); - } - - Ok(self - .get_auth_chain(room_id, &starting_events) + pub async fn event_ids_iter( + &self, room_id: &RoomId, starting_events: &[&EventId], + ) -> Result<impl Stream<Item = Arc<EventId>> + Send + '_> { + let stream = self + .get_event_ids(room_id, starting_events) .await? 
.into_iter() - .filter_map(move |sid| self.services.short.get_eventid_from_short(sid).ok())) + .stream(); + + Ok(stream) + } + + pub async fn get_event_ids(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result<Vec<Arc<EventId>>> { + let chain = self.get_auth_chain(room_id, starting_events).await?; + let event_ids = self + .services + .short + .multi_get_eventid_from_short(&chain) + .await + .into_iter() + .filter_map(Result::ok) + .collect(); + + Ok(event_ids) } #[tracing::instrument(skip_all, name = "auth_chain")] - pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result<Vec<u64>> { + pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result<Vec<ShortEventId>> { const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db? const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new(); let started = std::time::Instant::now(); - let mut buckets = [BUCKET; NUM_BUCKETS]; - for (i, &short) in self + let mut starting_ids = self .services .short - .multi_get_or_create_shorteventid(starting_events)? - .iter() + .multi_get_or_create_shorteventid(starting_events) .enumerate() - { + .boxed(); + + let mut buckets = [BUCKET; NUM_BUCKETS]; + while let Some((i, short)) = starting_ids.next().await { let bucket: usize = short.try_into()?; - let bucket: usize = validated!(bucket % NUM_BUCKETS)?; + let bucket: usize = validated!(bucket % NUM_BUCKETS); buckets[bucket].insert((short, starting_events[i])); } @@ -84,8 +97,8 @@ pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId continue; } - let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key)? { + let chunk_key: Vec<ShortEventId> = chunk.iter().map(|(short, _)| short).copied().collect(); + if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await { trace!("Found cache entry for whole chunk"); full_auth_chain.extend(cached.iter().copied()); hits = hits.saturating_add(1); @@ -96,13 +109,13 @@ pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId let mut misses2: usize = 0; let mut chunk_cache = Vec::with_capacity(chunk.len()); for (sevent_id, event_id) in chunk { - if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id])? 
{ + if let Ok(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await { trace!(?event_id, "Found cache entry for event"); chunk_cache.extend(cached.iter().copied()); hits2 = hits2.saturating_add(1); } else { - let auth_chain = self.get_auth_chain_inner(room_id, event_id)?; - self.cache_auth_chain(vec![sevent_id], &auth_chain)?; + let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?; + self.cache_auth_chain(vec![sevent_id], &auth_chain); chunk_cache.extend(auth_chain.iter()); misses2 = misses2.saturating_add(1); debug!( @@ -117,7 +130,7 @@ pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId chunk_cache.sort_unstable(); chunk_cache.dedup(); - self.cache_auth_chain_vec(chunk_key, &chunk_cache)?; + self.cache_auth_chain_vec(chunk_key, &chunk_cache); full_auth_chain.extend(chunk_cache.iter()); misses = misses.saturating_add(1); debug!( @@ -143,24 +156,31 @@ pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId } #[tracing::instrument(skip(self, room_id))] - fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> { + async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<ShortEventId>> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); while let Some(event_id) = todo.pop() { trace!(?event_id, "processing auth event"); - match self.services.timeline.get_pdu(&event_id) { - Ok(Some(pdu)) => { + match self.services.timeline.get_pdu(&event_id).await { + Err(e) => debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"), + Ok(pdu) => { if pdu.room_id != room_id { - return Err!(Request(Forbidden( - "auth event {event_id:?} for incorrect room {} which is not {}", - pdu.room_id, - room_id - ))); + return Err!(Request(Forbidden(error!( + ?event_id, + ?room_id, + wrong_room_id = ?pdu.room_id, + "auth event for incorrect room" + )))); } + for auth_event in &pdu.auth_events { - let sauthevent = self.services.short.get_or_create_shorteventid(auth_event)?; + let sauthevent = self + .services + .short + .get_or_create_shorteventid(auth_event) + .await; if found.insert(sauthevent) { trace!(?event_id, ?auth_event, "adding auth event to processing queue"); @@ -168,32 +188,27 @@ fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<H } } }, - Ok(None) => { - warn!(?event_id, "Could not find pdu mentioned in auth events"); - }, - Err(error) => { - error!(?event_id, ?error, "Could not load event in auth chain"); - }, } } Ok(found) } - pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> { - self.db.get_cached_eventid_authchain(key) + #[inline] + pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[ShortEventId]>> { + self.db.get_cached_eventid_authchain(key).await } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) -> Result<()> { - self.db - .cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>()) + pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<ShortEventId>) { + let val = auth_chain.iter().copied().collect::<Arc<[ShortEventId]>>(); + self.db.cache_auth_chain(key, val); } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) -> Result<()> { - self.db - .cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>()) + pub fn 
cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<ShortEventId>) { + let val = auth_chain.iter().copied().collect::<Arc<[ShortEventId]>>(); + self.db.cache_auth_chain(key, val); } pub fn get_cache_usage(&self) -> (usize, usize) { diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs deleted file mode 100644 index 713ee05762cbb597284462e59a4670b7c35e8bd9..0000000000000000000000000000000000000000 --- a/src/service/rooms/directory/data.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::{Database, Map}; -use ruma::{OwnedRoomId, RoomId}; - -pub(super) struct Data { - publicroomids: Arc<Map>, -} - -impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { - Self { - publicroomids: db["publicroomids"].clone(), - } - } - - pub(super) fn set_public(&self, room_id: &RoomId) -> Result<()> { - self.publicroomids.insert(room_id.as_bytes(), &[]) - } - - pub(super) fn set_not_public(&self, room_id: &RoomId) -> Result<()> { - self.publicroomids.remove(room_id.as_bytes()) - } - - pub(super) fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { - Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) - } - - pub(super) fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { - Box::new(self.publicroomids.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) - })) - } -} diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 706e6c2e5072554abebcb8705ad2a6815cd5aa1e..63ed3519fdaff18420fc51dadbd7baadb8e08f52 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -1,36 +1,47 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use ruma::{OwnedRoomId, RoomId}; - -use self::data::Data; +use conduit::{implement, utils::stream::TryIgnore, Result}; +use database::Map; +use futures::Stream; +use ruma::{api::client::room::Visibility, RoomId}; pub struct Service { db: Data, } +struct Data { + publicroomids: Arc<Map>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data { + publicroomids: args.db["publicroomids"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - #[tracing::instrument(skip(self), level = "debug")] - pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) } +#[implement(Service)] +pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_id, []); } + +#[implement(Service)] +pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id); } - #[tracing::instrument(skip(self), level = "debug")] - pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) } +#[implement(Service)] +pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send { self.db.publicroomids.keys().ignore_err() } - #[tracing::instrument(skip(self), level = "debug")] - pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) } +#[implement(Service)] +pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.visibility(room_id).await == Visibility::Public } - #[tracing::instrument(skip(self), level = "debug")] - 
pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() } +#[implement(Service)] +pub async fn visibility(&self, room_id: &RoomId) -> Visibility { + if self.db.publicroomids.get(room_id).await.is_ok() { + Visibility::Public + } else { + Visibility::Private + } } diff --git a/src/service/rooms/event_handler/acl_check.rs b/src/service/rooms/event_handler/acl_check.rs new file mode 100644 index 0000000000000000000000000000000000000000..f2ff1b0034d196d6186908309c697e2fdf5ee774 --- /dev/null +++ b/src/service/rooms/event_handler/acl_check.rs @@ -0,0 +1,35 @@ +use conduit::{debug, implement, trace, warn, Err, Result}; +use ruma::{ + events::{room::server_acl::RoomServerAclEventContent, StateEventType}, + RoomId, ServerName, +}; + +/// Returns Ok if the acl allows the server +#[implement(super::Service)] +#[tracing::instrument(skip_all)] +pub async fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result { + let Ok(acl_event_content) = self + .services + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomServerAcl, "") + .await + .map(|c: RoomServerAclEventContent| c) + .inspect(|acl| trace!("ACL content found: {acl:?}")) + .inspect_err(|e| trace!("No ACL content found: {e:?}")) + else { + return Ok(()); + }; + + if acl_event_content.allow.is_empty() { + warn!("Ignoring broken ACL event (allow key is empty)"); + return Ok(()); + } + + if acl_event_content.is_allowed(server_name) { + trace!("server {server_name} is allowed by ACL"); + Ok(()) + } else { + debug!("Server {server_name} was denied by room ACL in {room_id}"); + Err!(Request(Forbidden("Server was denied by room ACL"))) + } +} diff --git a/src/service/rooms/event_handler/fetch_and_handle_outliers.rs b/src/service/rooms/event_handler/fetch_and_handle_outliers.rs new file mode 100644 index 0000000000000000000000000000000000000000..677b78f21e8b3d1327f1d7d83448a635ea5b55c7 --- /dev/null +++ b/src/service/rooms/event_handler/fetch_and_handle_outliers.rs @@ -0,0 +1,181 @@ +use std::{ + collections::{hash_map, BTreeMap, HashSet}, + sync::Arc, + time::Instant, +}; + +use conduit::{ + debug, debug_error, implement, info, pdu, trace, utils::math::continue_exponential_backoff_secs, warn, PduEvent, +}; +use ruma::{api::federation::event::get_event, CanonicalJsonValue, EventId, RoomId, RoomVersionId, ServerName}; + +/// Find the event and auth it. Once the event is validated (steps 1 - 8) +/// it is appended to the outliers Tree. +/// +/// Returns pdu and if we fetched it over federation the raw json. +/// +/// a. Look in the main timeline (pduid_pdu tree) +/// b. Look at outlier pdu tree +/// c. Ask origin server over federation +/// d. TODO: Ask other servers over federation? +#[implement(super::Service)] +pub(super) async fn fetch_and_handle_outliers<'a>( + &self, origin: &'a ServerName, events: &'a [Arc<EventId>], create_event: &'a PduEvent, room_id: &'a RoomId, + room_version_id: &'a RoomVersionId, +) -> Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)> { + let back_off = |id| match self + .services + .globals + .bad_event_ratelimiter + .write() + .expect("locked") + .entry(id) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)), + }; + + let mut events_with_auth_events = Vec::with_capacity(events.len()); + for id in events { + // a. Look in the main timeline (pduid_pdu tree) + // b. 
Look at outlier pdu tree + // (get_pdu_json checks both) + if let Ok(local_pdu) = self.services.timeline.get_pdu(id).await { + trace!("Found {id} in db"); + events_with_auth_events.push((id, Some(local_pdu), vec![])); + continue; + } + + // c. Ask origin server over federation + // We also handle its auth chain here so we don't get a stack overflow in + // handle_outlier_pdu. + let mut todo_auth_events = vec![Arc::clone(id)]; + let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); + let mut events_all = HashSet::with_capacity(todo_auth_events.len()); + while let Some(next_id) = todo_auth_events.pop() { + if let Some((time, tries)) = self + .services + .globals + .bad_event_ratelimiter + .read() + .expect("locked") + .get(&*next_id) + { + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + info!("Backing off from {next_id}"); + continue; + } + } + + if events_all.contains(&next_id) { + continue; + } + + if self.services.timeline.pdu_exists(&next_id).await { + trace!("Found {next_id} in db"); + continue; + } + + debug!("Fetching {next_id} over federation."); + match self + .services + .sending + .send_federation_request( + origin, + get_event::v1::Request { + event_id: (*next_id).to_owned(), + include_unredacted_content: None, + }, + ) + .await + { + Ok(res) => { + debug!("Got {next_id} over federation"); + let Ok((calculated_event_id, value)) = pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) + else { + back_off((*next_id).to_owned()); + continue; + }; + + if calculated_event_id != *next_id { + warn!( + "Server didn't return event id we requested: requested: {next_id}, we got \ + {calculated_event_id}. Event: {:?}", + &res.pdu + ); + } + + if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) { + for auth_event in auth_events { + if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) { + let a: Arc<EventId> = auth_event; + todo_auth_events.push(a); + } else { + warn!("Auth event id is not valid"); + } + } + } else { + warn!("Auth event list invalid"); + } + + events_in_reverse_order.push((next_id.clone(), value)); + events_all.insert(next_id); + }, + Err(e) => { + debug_error!("Failed to fetch event {next_id}: {e}"); + back_off((*next_id).to_owned()); + }, + } + } + events_with_auth_events.push((id, None, events_in_reverse_order)); + } + + let mut pdus = Vec::with_capacity(events_with_auth_events.len()); + for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { + // a. Look in the main timeline (pduid_pdu tree) + // b. 
Look at outlier pdu tree + // (get_pdu_json checks both) + if let Some(local_pdu) = local_pdu { + trace!("Found {id} in db"); + pdus.push((local_pdu.clone(), None)); + } + + for (next_id, value) in events_in_reverse_order.into_iter().rev() { + if let Some((time, tries)) = self + .services + .globals + .bad_event_ratelimiter + .read() + .expect("locked") + .get(&*next_id) + { + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + debug!("Backing off from {next_id}"); + continue; + } + } + + match Box::pin(self.handle_outlier_pdu(origin, create_event, &next_id, room_id, value.clone(), true)).await + { + Ok((pdu, json)) => { + if next_id == *id { + pdus.push((pdu, Some(json))); + } + }, + Err(e) => { + warn!("Authentication of event {next_id} failed: {e:?}"); + back_off(next_id.into()); + }, + } + } + } + pdus +} diff --git a/src/service/rooms/event_handler/fetch_prev.rs b/src/service/rooms/event_handler/fetch_prev.rs new file mode 100644 index 0000000000000000000000000000000000000000..4acdba1dc681f5799b63718e03882b32cac49941 --- /dev/null +++ b/src/service/rooms/event_handler/fetch_prev.rs @@ -0,0 +1,104 @@ +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + sync::Arc, +}; + +use conduit::{debug_warn, err, implement, PduEvent, Result}; +use futures::{future, FutureExt}; +use ruma::{ + int, + state_res::{self}, + uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName, +}; + +use super::check_room_id; + +#[implement(super::Service)] +#[allow(clippy::type_complexity)] +#[tracing::instrument(skip_all)] +pub(super) async fn fetch_prev( + &self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, + initial_set: Vec<Arc<EventId>>, +) -> Result<( + Vec<Arc<EventId>>, + HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>, +)> { + let mut graph: HashMap<Arc<EventId>, _> = HashMap::with_capacity(initial_set.len()); + let mut eventid_info = HashMap::new(); + let mut todo_outlier_stack: Vec<Arc<EventId>> = initial_set; + + let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; + + let mut amount = 0; + + while let Some(prev_event_id) = todo_outlier_stack.pop() { + self.services.server.check_running()?; + + if let Some((pdu, mut json_opt)) = self + .fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id) + .boxed() + .await + .pop() + { + check_room_id(room_id, &pdu)?; + + let limit = self.services.server.config.max_fetch_prev_events; + if amount > limit { + debug_warn!("Max prev event limit reached! 
Limit: {limit}"); + graph.insert(prev_event_id.clone(), HashSet::new()); + continue; + } + + if json_opt.is_none() { + json_opt = self + .services + .outlier + .get_outlier_pdu_json(&prev_event_id) + .await + .ok(); + } + + if let Some(json) = json_opt { + if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { + amount = amount.saturating_add(1); + for prev_prev in &pdu.prev_events { + if !graph.contains_key(prev_prev) { + todo_outlier_stack.push(prev_prev.clone()); + } + } + + graph.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect()); + } else { + // Time based check failed + graph.insert(prev_event_id.clone(), HashSet::new()); + } + + eventid_info.insert(prev_event_id.clone(), (pdu, json)); + } else { + // Get json failed, so this was not fetched over federation + graph.insert(prev_event_id.clone(), HashSet::new()); + } + } else { + // Fetch and handle failed + graph.insert(prev_event_id.clone(), HashSet::new()); + } + } + + let event_fetch = |event_id| { + let origin_server_ts = eventid_info + .get(&event_id) + .cloned() + .map_or_else(|| uint!(0), |info| info.0.origin_server_ts); + + // This return value is the key used for sorting events, + // events are then sorted by power level, time, + // and lexically by event_id. + future::ok((int!(0), MilliSecondsSinceUnixEpoch(origin_server_ts))) + }; + + let sorted = state_res::lexicographical_topological_sort(&graph, &event_fetch) + .await + .map_err(|e| err!(Database(error!("Error sorting prev events: {e}"))))?; + + Ok((sorted, eventid_info)) +} diff --git a/src/service/rooms/event_handler/fetch_state.rs b/src/service/rooms/event_handler/fetch_state.rs new file mode 100644 index 0000000000000000000000000000000000000000..74b0bb32a582a6fa07c18b347a4147c6a08ee143 --- /dev/null +++ b/src/service/rooms/event_handler/fetch_state.rs @@ -0,0 +1,84 @@ +use std::{ + collections::{hash_map, HashMap}, + sync::Arc, +}; + +use conduit::{debug, implement, warn, Err, Error, PduEvent, Result}; +use futures::FutureExt; +use ruma::{ + api::federation::event::get_room_state_ids, events::StateEventType, EventId, RoomId, RoomVersionId, ServerName, +}; + +/// Call /state_ids to find out what the state at this pdu is. 
We trust the
+/// server's response to some extent, but we still do a lot of checks
+/// on the events
+#[implement(super::Service)]
+#[tracing::instrument(skip(self, create_event, room_version_id))]
+pub(super) async fn fetch_state(
+	&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
+	event_id: &EventId,
+) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
+	debug!("Fetching state ids");
+	let res = self
+		.services
+		.sending
+		.send_synapse_request(
+			origin,
+			get_room_state_ids::v1::Request {
+				room_id: room_id.to_owned(),
+				event_id: (*event_id).to_owned(),
+			},
+		)
+		.await
+		.inspect_err(|e| warn!("Fetching state for event failed: {e}"))?;
+
+	debug!("Fetching state events");
+	let collect = res
+		.pdu_ids
+		.iter()
+		.map(|x| Arc::from(&**x))
+		.collect::<Vec<_>>();
+
+	let state_vec = self
+		.fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id)
+		.boxed()
+		.await;
+
+	let mut state: HashMap<_, Arc<EventId>> = HashMap::with_capacity(state_vec.len());
+	for (pdu, _) in state_vec {
+		let state_key = pdu
+			.state_key
+			.clone()
+			.ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?;
+
+		let shortstatekey = self
+			.services
+			.short
+			.get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)
+			.await;
+
+		match state.entry(shortstatekey) {
+			hash_map::Entry::Vacant(v) => {
+				v.insert(Arc::from(&*pdu.event_id));
+			},
+			hash_map::Entry::Occupied(_) => {
+				return Err(Error::bad_database(
+					"State event's type and state_key combination exists multiple times.",
+				))
+			},
+		}
+	}
+
+	// The original create event must still be in the state
+	let create_shortstatekey = self
+		.services
+		.short
+		.get_shortstatekey(&StateEventType::RoomCreate, "")
+		.await?;
+
+	if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(&create_event.event_id) {
+		return Err!(Database("Incoming event refers to wrong create event."));
+	}
+
+	Ok(Some(state))
+}
diff --git a/src/service/rooms/event_handler/handle_incoming_pdu.rs b/src/service/rooms/event_handler/handle_incoming_pdu.rs
new file mode 100644
index 0000000000000000000000000000000000000000..4d2d75d5f70d9459bcb895516195d50c7e117f51
--- /dev/null
+++ b/src/service/rooms/event_handler/handle_incoming_pdu.rs
@@ -0,0 +1,172 @@
+use std::{
+	collections::{hash_map, BTreeMap},
+	time::Instant,
+};
+
+use conduit::{debug, err, implement, warn, Error, Result};
+use futures::FutureExt;
+use ruma::{
+	api::client::error::ErrorKind, events::StateEventType, CanonicalJsonValue, EventId, RoomId, ServerName, UserId,
+};
+
+use super::{check_room_id, get_room_version_id};
+use crate::rooms::timeline::RawPduId;
+
+/// When receiving an event one needs to:
+/// 0. Check the server is in the room
+/// 1. Skip the PDU if we already know about it
+/// 1.1. Remove unsigned field
+/// 2. Check signatures, otherwise drop
+/// 3. Check content hash, redact if doesn't match
+/// 4. Fetch any missing auth events doing all checks listed here starting at 1.
+/// These are not timeline events
+/// 5. Reject "due to auth events" if can't get all the auth events or some of
+/// the auth events are also rejected "due to auth events"
+/// 6. Reject "due to auth events" if the event doesn't pass auth based on the
+/// auth events
+/// 7. Persist this event as an outlier
+/// 8. If not timeline event: stop
+/// 9. Fetch any missing prev events doing all checks listed here starting at 1.
+/// These are timeline events
+/// 10. 
Fetch missing state and auth chain events by calling `/state_ids` at +/// backwards extremities doing all the checks in this list starting at +/// 1. These are not timeline events +/// 11. Check the auth of the event passes based on the state of the event +/// 12. Ensure that the state is derived from the previous current state (i.e. +/// we calculated by doing state res where one of the inputs was a +/// previously trusted set of state, don't just trust a set of state we got +/// from a remote) +/// 13. Use state resolution to find new room state +/// 14. Check if the event passes auth based on the "current state" of the room, +/// if not soft fail it +#[implement(super::Service)] +#[tracing::instrument(skip(self, origin, value, is_timeline_event), name = "pdu")] +pub async fn handle_incoming_pdu<'a>( + &self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId, + value: BTreeMap<String, CanonicalJsonValue>, is_timeline_event: bool, +) -> Result<Option<RawPduId>> { + // 1. Skip the PDU if we already have it as a timeline event + if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await { + return Ok(Some(pdu_id)); + } + + // 1.1 Check the server is in the room + if !self.services.metadata.exists(room_id).await { + return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); + } + + // 1.2 Check if the room is disabled + if self.services.metadata.is_disabled(room_id).await { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Federation of this room is currently disabled on this server.", + )); + } + + // 1.3.1 Check room ACL on origin field/server + self.acl_check(origin, room_id).await?; + + // 1.3.2 Check room ACL on sender's server name + let sender: &UserId = value + .get("sender") + .try_into() + .map_err(|e| err!(Request(InvalidParam("PDU does not have a valid sender key: {e}"))))?; + + self.acl_check(sender.server_name(), room_id).await?; + + // Fetch create event + let create_event = self + .services + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await?; + + // Procure the room version + let room_version_id = get_room_version_id(&create_event)?; + + let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?; + + let (incoming_pdu, val) = self + .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false) + .boxed() + .await?; + + check_room_id(room_id, &incoming_pdu)?; + + // 8. if not timeline event: stop + if !is_timeline_event { + return Ok(None); + } + // Skip old events + if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { + return Ok(None); + } + + // 9. Fetch any missing prev events doing all checks listed here starting at 1. 
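+	// (prev events that fail here are recorded in bad_event_ratelimiter below
+	// and retried with exponential backoff, instead of failing this whole PDU.)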
+ // These are timeline events + let (sorted_prev_events, mut eventid_info) = self + .fetch_prev( + origin, + &create_event, + room_id, + &room_version_id, + incoming_pdu.prev_events.clone(), + ) + .await?; + + debug!(events = ?sorted_prev_events, "Got previous events"); + for prev_id in sorted_prev_events { + self.services.server.check_running()?; + if let Err(e) = self + .handle_prev_pdu( + origin, + event_id, + room_id, + &mut eventid_info, + &create_event, + &first_pdu_in_room, + &prev_id, + ) + .await + { + use hash_map::Entry; + + let now = Instant::now(); + warn!("Prev event {prev_id} failed: {e}"); + + match self + .services + .globals + .bad_event_ratelimiter + .write() + .expect("locked") + .entry(prev_id.into()) + { + Entry::Vacant(e) => { + e.insert((now, 1)); + }, + Entry::Occupied(mut e) => { + *e.get_mut() = (now, e.get().1.saturating_add(1)); + }, + }; + } + } + + // Done with prev events, now handling the incoming event + let start_time = Instant::now(); + self.federation_handletime + .write() + .expect("locked") + .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); + + let r = self + .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id) + .await; + + self.federation_handletime + .write() + .expect("locked") + .remove(&room_id.to_owned()); + + r +} diff --git a/src/service/rooms/event_handler/handle_outlier_pdu.rs b/src/service/rooms/event_handler/handle_outlier_pdu.rs new file mode 100644 index 0000000000000000000000000000000000000000..2d95ff6379ec7675825b72567a85f5fa02b049b7 --- /dev/null +++ b/src/service/rooms/event_handler/handle_outlier_pdu.rs @@ -0,0 +1,164 @@ +use std::{ + collections::{hash_map, BTreeMap, HashMap}, + sync::Arc, +}; + +use conduit::{debug, debug_info, err, implement, trace, warn, Err, Error, PduEvent, Result}; +use futures::future::ready; +use ruma::{ + api::client::error::ErrorKind, + events::StateEventType, + state_res::{self, EventTypeExt}, + CanonicalJsonObject, CanonicalJsonValue, EventId, RoomId, ServerName, +}; + +use super::{check_room_id, get_room_version_id, to_room_version}; + +#[implement(super::Service)] +#[allow(clippy::too_many_arguments)] +pub(super) async fn handle_outlier_pdu<'a>( + &self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, + mut value: CanonicalJsonObject, auth_events_known: bool, +) -> Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)> { + // 1. Remove unsigned field + value.remove("unsigned"); + + // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json + + // 2. Check signatures, otherwise drop + // 3. 
check content hash, redact if doesn't match + let room_version_id = get_room_version_id(create_event)?; + let mut val = match self + .services + .server_keys + .verify_event(&value, Some(&room_version_id)) + .await + { + Ok(ruma::signatures::Verified::All) => value, + Ok(ruma::signatures::Verified::Signatures) => { + // Redact + debug_info!("Calculated hash does not match (redaction): {event_id}"); + let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else { + return Err!(Request(InvalidParam("Redaction failed"))); + }; + + // Skip the PDU if it is redacted and we already have it as an outlier event + if self.services.timeline.pdu_exists(event_id).await { + return Err!(Request(InvalidParam("Event was redacted and we already knew about it"))); + } + + obj + }, + Err(e) => { + return Err!(Request(InvalidParam(debug_error!( + "Signature verification failed for {event_id}: {e}" + )))) + }, + }; + + // Now that we have checked the signature and hashes we can add the eventID and + // convert to our PduEvent type + val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + let incoming_pdu = + serde_json::from_value::<PduEvent>(serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue")) + .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; + + check_room_id(room_id, &incoming_pdu)?; + + if !auth_events_known { + // 4. fetch any missing auth events doing all checks listed here starting at 1. + // These are not timeline events + // 5. Reject "due to auth events" if can't get all the auth events or some of + // the auth events are also rejected "due to auth events" + // NOTE: Step 5 is not applied anymore because it failed too often + debug!("Fetching auth events"); + Box::pin( + self.fetch_and_handle_outliers( + origin, + &incoming_pdu + .auth_events + .iter() + .map(|x| Arc::from(&**x)) + .collect::<Vec<Arc<EventId>>>(), + create_event, + room_id, + &room_version_id, + ), + ) + .await; + } + + // 6. 
Reject "due to auth events" if the event doesn't pass auth based on the
+	// auth events
+	debug!("Checking based on auth events");
+	// Build map of auth events
+	let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len());
+	for id in &incoming_pdu.auth_events {
+		let Ok(auth_event) = self.services.timeline.get_pdu(id).await else {
+			warn!("Could not find auth event {id}");
+			continue;
+		};
+
+		check_room_id(room_id, &auth_event)?;
+
+		match auth_events.entry((
+			auth_event.kind.to_string().into(),
+			auth_event
+				.state_key
+				.clone()
+				.expect("all auth events have state keys"),
+		)) {
+			hash_map::Entry::Vacant(v) => {
+				v.insert(auth_event);
+			},
+			hash_map::Entry::Occupied(_) => {
+				return Err(Error::BadRequest(
+					ErrorKind::InvalidParam,
+					"Auth event's type and state_key combination exists multiple times.",
+				));
+			},
+		}
+	}
+
+	// The original create event must be in the auth events
+	if !matches!(
+		auth_events
+			.get(&(StateEventType::RoomCreate, String::new()))
+			.map(AsRef::as_ref),
+		Some(_) | None
+	) {
+		return Err(Error::BadRequest(
+			ErrorKind::InvalidParam,
+			"Incoming event refers to wrong create event.",
+		));
+	}
+
+	let state_fetch = |ty: &'static StateEventType, sk: &str| {
+		let key = ty.with_state_key(sk);
+		ready(auth_events.get(&key))
+	};
+
+	let auth_check = state_res::event_auth::auth_check(
+		&to_room_version(&room_version_id),
+		&incoming_pdu,
+		None, // TODO: third party invite
+		state_fetch,
+	)
+	.await
+	.map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?;
+
+	if !auth_check {
+		return Err!(Request(Forbidden("Auth check failed")));
+	}
+
+	trace!("Validation successful.");
+
+	// 7. Persist the event as an outlier.
+	self.services
+		.outlier
+		.add_pdu_outlier(&incoming_pdu.event_id, &val);
+
+	trace!("Added pdu as outlier.");
+
+	Ok((Arc::new(incoming_pdu), val))
+}
diff --git a/src/service/rooms/event_handler/handle_prev_pdu.rs b/src/service/rooms/event_handler/handle_prev_pdu.rs
new file mode 100644
index 0000000000000000000000000000000000000000..90ff7f06b081e17f9b18d84f9e9b0fdabce291ab
--- /dev/null
+++ b/src/service/rooms/event_handler/handle_prev_pdu.rs
@@ -0,0 +1,82 @@
+use std::{
+	collections::{BTreeMap, HashMap},
+	sync::Arc,
+	time::Instant,
+};
+
+use conduit::{debug, implement, utils::math::continue_exponential_backoff_secs, Error, PduEvent, Result};
+use ruma::{api::client::error::ErrorKind, CanonicalJsonValue, EventId, RoomId, ServerName};
+
+#[implement(super::Service)]
+#[allow(clippy::type_complexity)]
+#[allow(clippy::too_many_arguments)]
+#[tracing::instrument(
+	skip(self, origin, event_id, room_id, eventid_info, create_event, first_pdu_in_room),
+	name = "prev"
+)]
+pub(super) async fn handle_prev_pdu<'a>(
+	&self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId,
+	eventid_info: &mut HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>,
+	create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId,
+) -> Result {
+	// Check for disabled again because it might have changed
+	if self.services.metadata.is_disabled(room_id).await {
+		debug!(
+			"Federation of room {room_id} is currently disabled on this server. 
Request by origin {origin} and event \ + ID {event_id}" + ); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Federation of this room is currently disabled on this server.", + )); + } + + if let Some((time, tries)) = self + .services + .globals + .bad_event_ratelimiter + .read() + .expect("locked") + .get(prev_id) + { + // Exponential backoff + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + debug!( + ?tries, + duration = ?time.elapsed(), + "Backing off from prev_event" + ); + return Ok(()); + } + } + + if let Some((pdu, json)) = eventid_info.remove(prev_id) { + // Skip old events + if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { + return Ok(()); + } + + let start_time = Instant::now(); + self.federation_handletime + .write() + .expect("locked") + .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); + + self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id) + .await?; + + self.federation_handletime + .write() + .expect("locked") + .remove(&room_id.to_owned()); + + debug!( + elapsed = ?start_time.elapsed(), + "Handled prev_event", + ); + } + + Ok(()) +} diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index bee986deb9a1cc57ef27115a05ae2b30e4045287..f6440fe936567699365c1ab2dd68cb18d7ccd130 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,45 +1,34 @@ +mod acl_check; +mod fetch_and_handle_outliers; +mod fetch_prev; +mod fetch_state; +mod handle_incoming_pdu; +mod handle_outlier_pdu; +mod handle_prev_pdu; mod parse_incoming_pdu; +mod resolve_state; +mod state_at_incoming; +mod upgrade_outlier_pdu; use std::{ - collections::{hash_map, BTreeMap, HashMap, HashSet}, + collections::HashMap, fmt::Write, - pin::Pin, sync::{Arc, RwLock as StdRwLock}, time::Instant, }; -use conduit::{ - debug, debug_error, debug_info, err, error, info, pdu, trace, - utils::{math::continue_exponential_backoff_secs, MutexMap}, - warn, Error, PduEvent, Result, -}; -use futures_util::Future; +use conduit::{utils::MutexMap, Err, PduEvent, Result, Server}; use ruma::{ - api::{ - client::error::ErrorKind, - federation::event::{get_event, get_room_state_ids}, - }, - events::{ - room::{ - create::RoomCreateEventContent, redaction::RoomRedactionEventContent, server_acl::RoomServerAclEventContent, - }, - StateEventType, TimelineEventType, - }, - int, - serde::Base64, - state_res::{self, RoomVersion, StateMap}, - uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, - RoomVersionId, ServerName, + events::room::create::RoomCreateEventContent, state_res::RoomVersion, EventId, OwnedEventId, OwnedRoomId, RoomId, + RoomVersionId, }; -use tokio::sync::RwLock; -use super::state_compressor::CompressedStateEvent; use crate::{globals, rooms, sending, server_keys, Dep}; pub struct Service { - services: Services, - pub federation_handletime: StdRwLock<HandleTimeMap>, pub mutex_federation: RoomMutexMap, + pub federation_handletime: StdRwLock<HandleTimeMap>, + services: Services, } struct Services { @@ -55,22 +44,17 @@ struct Services { state_accessor: Dep<rooms::state_accessor::Service>, state_compressor: Dep<rooms::state_compressor::Service>, timeline: Dep<rooms::timeline::Service>, + server: Arc<Server>, } type RoomMutexMap = MutexMap<OwnedRoomId, ()>; type HandleTimeMap = HashMap<OwnedRoomId, 
(OwnedEventId, Instant)>; -// We use some AsyncRecursiveType hacks here so we can call async funtion -// recursively. -type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>; -type AsyncRecursiveCanonicalJsonVec<'a> = - AsyncRecursiveType<'a, Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)>>; -type AsyncRecursiveCanonicalJsonResult<'a> = - AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>>; - impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + mutex_federation: RoomMutexMap::new(), + federation_handletime: HandleTimeMap::new().into(), services: Services { globals: args.depend::<globals::Service>("globals"), sending: args.depend::<sending::Service>("sending"), @@ -84,9 +68,8 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"), timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + server: args.server.clone(), }, - federation_handletime: HandleTimeMap::new().into(), - mutex_federation: RoomMutexMap::new(), })) } @@ -108,1304 +91,34 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - /// When receiving an event one needs to: - /// 0. Check the server is in the room - /// 1. Skip the PDU if we already know about it - /// 1.1. Remove unsigned field - /// 2. Check signatures, otherwise drop - /// 3. Check content hash, redact if doesn't match - /// 4. Fetch any missing auth events doing all checks listed here starting - /// at 1. These are not timeline events - /// 5. Reject "due to auth events" if can't get all the auth events or some - /// of the auth events are also rejected "due to auth events" - /// 6. Reject "due to auth events" if the event doesn't pass auth based on - /// the auth events - /// 7. Persist this event as an outlier - /// 8. If not timeline event: stop - /// 9. Fetch any missing prev events doing all checks listed here starting - /// at 1. These are timeline events - /// 10. Fetch missing state and auth chain events by calling `/state_ids` at - /// backwards extremities doing all the checks in this list starting at - /// 1. These are not timeline events - /// 11. Check the auth of the event passes based on the state of the event - /// 12. Ensure that the state is derived from the previous current state - /// (i.e. we calculated by doing state res where one of the inputs was a - /// previously trusted set of state, don't just trust a set of state we - /// got from a remote) - /// 13. Use state resolution to find new room state - /// 14. Check if the event passes auth based on the "current state" of the - /// room, if not soft fail it - #[tracing::instrument(skip(self, origin, value, is_timeline_event, pub_key_map), name = "pdu")] - pub async fn handle_incoming_pdu<'a>( - &self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId, - value: BTreeMap<String, CanonicalJsonValue>, is_timeline_event: bool, - pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - ) -> Result<Option<Vec<u8>>> { - // 1. Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = self.services.timeline.get_pdu_id(event_id)? { - return Ok(Some(pdu_id.to_vec())); - } - - // 1.1 Check the server is in the room - if !self.services.metadata.exists(room_id)? 
{ - return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); - } - - // 1.2 Check if the room is disabled - if self.services.metadata.is_disabled(room_id)? { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Federation of this room is currently disabled on this server.", - )); - } - - // 1.3.1 Check room ACL on origin field/server - self.acl_check(origin, room_id)?; - - // 1.3.2 Check room ACL on sender's server name - let sender: OwnedUserId = serde_json::from_value( - value - .get("sender") - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "PDU does not have a sender key"))? - .clone() - .into(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "User ID in sender is invalid"))?; - - self.acl_check(sender.server_name(), room_id)?; - - // Fetch create event - let create_event = self - .services - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; - - // Procure the room version - let room_version_id = Self::get_room_version_id(&create_event)?; - - let first_pdu_in_room = self - .services - .timeline - .first_pdu_in_room(room_id)? - .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; - - let (incoming_pdu, val) = self - .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map) - .await?; - - Self::check_room_id(room_id, &incoming_pdu)?; - - // 8. if not timeline event: stop - if !is_timeline_event { - return Ok(None); - } - // Skip old events - if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { - return Ok(None); - } - - // 9. Fetch any missing prev events doing all checks listed here starting at 1. - // These are timeline events - let (sorted_prev_events, mut eventid_info) = self - .fetch_prev( - origin, - &create_event, - room_id, - &room_version_id, - pub_key_map, - incoming_pdu.prev_events.clone(), - ) - .await?; - - debug!(events = ?sorted_prev_events, "Got previous events"); - for prev_id in sorted_prev_events { - match self - .handle_prev_pdu( - origin, - event_id, - room_id, - pub_key_map, - &mut eventid_info, - &create_event, - &first_pdu_in_room, - &prev_id, - ) - .await - { - Ok(()) => continue, - Err(e) => { - warn!("Prev event {} failed: {}", prev_id, e); - match self - .services - .globals - .bad_event_ratelimiter - .write() - .expect("locked") - .entry((*prev_id).to_owned()) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)); - }, - }; - }, - } - } - - // Done with prev events, now handling the incoming event - let start_time = Instant::now(); - self.federation_handletime - .write() - .expect("locked") - .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); - - let r = self - .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map) - .await; - - self.federation_handletime - .write() - .expect("locked") - .remove(&room_id.to_owned()); - - r - } - - #[allow(clippy::type_complexity)] - #[allow(clippy::too_many_arguments)] - #[tracing::instrument( - skip(self, origin, event_id, room_id, pub_key_map, eventid_info, create_event, first_pdu_in_room), - name = "prev" - )] - pub async fn handle_prev_pdu<'a>( - &self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, - pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - eventid_info: &mut 
HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>, - create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId, - ) -> Result<()> { - // Check for disabled again because it might have changed - if self.services.metadata.is_disabled(room_id)? { - debug!( - "Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \ - event ID {event_id}" - ); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Federation of this room is currently disabled on this server.", - )); - } - - if let Some((time, tries)) = self - .services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(prev_id) - { - // Exponential backoff - const MIN_DURATION: u64 = 5 * 60; - const MAX_DURATION: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { - debug!( - ?tries, - duration = ?time.elapsed(), - "Backing off from prev_event" - ); - return Ok(()); - } - } - - if let Some((pdu, json)) = eventid_info.remove(prev_id) { - // Skip old events - if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { - return Ok(()); - } - - let start_time = Instant::now(); - self.federation_handletime - .write() - .expect("locked") - .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); - - self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id, pub_key_map) - .await?; - - self.federation_handletime - .write() - .expect("locked") - .remove(&room_id.to_owned()); - - debug!( - elapsed = ?start_time.elapsed(), - "Handled prev_event", - ); - } - - Ok(()) - } - - #[allow(clippy::too_many_arguments)] - fn handle_outlier_pdu<'a>( - &'a self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, - mut value: BTreeMap<String, CanonicalJsonValue>, auth_events_known: bool, - pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - ) -> AsyncRecursiveCanonicalJsonResult<'a> { - Box::pin(async move { - // 1. Remove unsigned field - value.remove("unsigned"); - - // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - - // 2. Check signatures, otherwise drop - // 3. 
check content hash, redact if doesn't match - let room_version_id = Self::get_room_version_id(create_event)?; - - let guard = pub_key_map.read().await; - let mut val = match ruma::signatures::verify_event(&guard, &value, &room_version_id) { - Err(e) => { - // Drop - warn!("Dropping bad event {}: {}", event_id, e,); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Signature verification failed")); - }, - Ok(ruma::signatures::Verified::Signatures) => { - // Redact - debug_info!("Calculated hash does not match (redaction): {event_id}"); - let Ok(obj) = ruma::canonical_json::redact(value, &room_version_id, None) else { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Redaction failed")); - }; - - // Skip the PDU if it is redacted and we already have it as an outlier event - if self.services.timeline.get_pdu_json(event_id)?.is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Event was redacted and we already knew about it", - )); - } + async fn event_exists(&self, event_id: Arc<EventId>) -> bool { self.services.timeline.pdu_exists(&event_id).await } - obj - }, - Ok(ruma::signatures::Verified::All) => value, - }; - - drop(guard); - - // Now that we have checked the signature and hashes we can add the eventID and - // convert to our PduEvent type - val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - let incoming_pdu = serde_json::from_value::<PduEvent>( - serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), - ) - .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; - - Self::check_room_id(room_id, &incoming_pdu)?; - - if !auth_events_known { - // 4. fetch any missing auth events doing all checks listed here starting at 1. - // These are not timeline events - // 5. Reject "due to auth events" if can't get all the auth events or some of - // the auth events are also rejected "due to auth events" - // NOTE: Step 5 is not applied anymore because it failed too often - debug!("Fetching auth events"); - self.fetch_and_handle_outliers( - origin, - &incoming_pdu - .auth_events - .iter() - .map(|x| Arc::from(&**x)) - .collect::<Vec<_>>(), - create_event, - room_id, - &room_version_id, - pub_key_map, - ) - .await; - } - - // 6. Reject "due to auth events" if the event doesn't pass auth based on the - // auth events - debug!("Checking based on auth events"); - // Build map of auth events - let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); - for id in &incoming_pdu.auth_events { - let Some(auth_event) = self.services.timeline.get_pdu(id)? 
else { - warn!("Could not find auth event {}", id); - continue; - }; - - Self::check_room_id(room_id, &auth_event)?; - - match auth_events.entry(( - auth_event.kind.to_string().into(), - auth_event - .state_key - .clone() - .expect("all auth events have state keys"), - )) { - hash_map::Entry::Vacant(v) => { - v.insert(auth_event); - }, - hash_map::Entry::Occupied(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Auth event's type and state_key combination exists multiple times.", - )); - }, - } - } - - // The original create event must be in the auth events - if !matches!( - auth_events - .get(&(StateEventType::RoomCreate, String::new())) - .map(AsRef::as_ref), - Some(_) | None - ) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Incoming event refers to wrong create event.", - )); - } - - if !state_res::event_auth::auth_check( - &Self::to_room_version(&room_version_id), - &incoming_pdu, - None::<PduEvent>, // TODO: third party invite - |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), - ) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed"))? - { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")); - } - - trace!("Validation successful."); - - // 7. Persist the event as an outlier. - self.services - .outlier - .add_pdu_outlier(&incoming_pdu.event_id, &val)?; - - trace!("Added pdu as outlier."); - - Ok((Arc::new(incoming_pdu), val)) - }) + async fn event_fetch(&self, event_id: Arc<EventId>) -> Option<Arc<PduEvent>> { + self.services.timeline.get_pdu(&event_id).await.ok() } +} - pub async fn upgrade_outlier_to_timeline_pdu( - &self, incoming_pdu: Arc<PduEvent>, val: BTreeMap<String, CanonicalJsonValue>, create_event: &PduEvent, - origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - ) -> Result<Option<Vec<u8>>> { - // Skip the PDU if we already have it as a timeline event - if let Ok(Some(pduid)) = self.services.timeline.get_pdu_id(&incoming_pdu.event_id) { - return Ok(Some(pduid.to_vec())); - } - - if self - .services - .pdu_metadata - .is_event_soft_failed(&incoming_pdu.event_id)? - { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); - } - - debug!("Upgrading to timeline pdu"); - let timer = tokio::time::Instant::now(); - let room_version_id = Self::get_room_version_id(create_event)?; - - // 10. Fetch missing state and auth chain events by calling /state_ids at - // backwards extremities doing all the checks in this list starting at 1. - // These are not timeline events. - - debug!("Resolving state at event"); - let mut state_at_incoming_event = if incoming_pdu.prev_events.len() == 1 { - self.state_at_incoming_degree_one(&incoming_pdu).await? - } else { - self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_id) - .await? - }; - - if state_at_incoming_event.is_none() { - state_at_incoming_event = self - .fetch_state( - origin, - create_event, - room_id, - &room_version_id, - pub_key_map, - &incoming_pdu.event_id, - ) - .await?; - } - - let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above"); - let room_version = Self::to_room_version(&room_version_id); - - debug!("Performing auth check"); - // 11. 
Check the auth of the event passes based on the state of the event - let check_result = state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None::<PduEvent>, // TODO: third party invite - |k, s| { - self.services - .short - .get_shortstatekey(&k.to_string().into(), s) - .ok() - .flatten() - .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| self.services.timeline.get_pdu(event_id).ok().flatten()) - }, - ) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))?; - - if !check_result { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Event has failed auth check with state at the event.", - )); - } - - debug!("Gathering auth events"); - let auth_events = self.services.state.get_auth_events( - room_id, - &incoming_pdu.kind, - &incoming_pdu.sender, - incoming_pdu.state_key.as_deref(), - &incoming_pdu.content, - )?; - - // Soft fail check before doing state res - debug!("Performing soft-fail check"); - let soft_fail = { - use RoomVersionId::*; - - !state_res::event_auth::auth_check(&room_version, &incoming_pdu, None::<PduEvent>, |k, s| { - auth_events.get(&(k.clone(), s.to_owned())) - }) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))? - || incoming_pdu.kind == TimelineEventType::RoomRedaction - && match room_version_id { - V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { - if let Some(redact_id) = &incoming_pdu.redacts { - !self.services.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? - } else { - false - } - }, - _ => { - let content = serde_json::from_str::<RoomRedactionEventContent>(incoming_pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; - - if let Some(redact_id) = &content.redacts { - !self.services.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? - } else { - false - } - }, - } - }; - - // 13. Use state resolution to find new room state - - // We start looking at current room state now, so lets lock the room - trace!("Locking the room"); - let state_lock = self.services.state.mutex.lock(room_id).await; - - // Now we calculate the set of extremities this room has after the incoming - // event has been applied. We start with the previous extremities (aka leaves) - trace!("Calculating extremities"); - let mut extremities = self.services.state.get_forward_extremities(room_id)?; - trace!("Calculated {} extremities", extremities.len()); - - // Remove any forward extremities that are referenced by this incoming event's - // prev_events - for prev_event in &incoming_pdu.prev_events { - extremities.remove(prev_event); - } - - // Only keep those extremities were not referenced yet - extremities.retain(|id| !matches!(self.services.pdu_metadata.is_event_referenced(room_id, id), Ok(true))); - debug!("Retained {} extremities. Compressing state", extremities.len()); - let state_ids_compressed = Arc::new( - state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - self.services - .state_compressor - .compress_state_event(*shortstatekey, id) - }) - .collect::<Result<_>>()?, - ); - - if incoming_pdu.state_key.is_some() { - debug!("Event is a state-event. 
Deriving new room state"); - - // We also add state after incoming event to the fork states - let mut state_after = state_at_incoming_event.clone(); - if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?; - - state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); - } - - let new_room_state = self - .resolve_state(room_id, &room_version_id, state_after) - .await?; - - // Set the new room state to the resolved state - debug!("Forcing new room state"); - let (sstatehash, new, removed) = self - .services - .state_compressor - .save_state(room_id, new_room_state)?; - - self.services - .state - .force_state(room_id, sstatehash, new, removed, &state_lock) - .await?; - } - - // 14. Check if the event passes auth based on the "current state" of the room, - // if not soft fail it - if soft_fail { - debug!("Soft failing event"); - self.services - .timeline - .append_incoming_pdu( - &incoming_pdu, - val, - extremities.iter().map(|e| (**e).to_owned()).collect(), - state_ids_compressed, - soft_fail, - &state_lock, - ) - .await?; - - // Soft fail, we keep the event as an outlier but don't add it to the timeline - warn!("Event was soft failed: {:?}", incoming_pdu); - self.services - .pdu_metadata - .mark_event_soft_failed(&incoming_pdu.event_id)?; - - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); - } - - trace!("Appending pdu to timeline"); - extremities.insert(incoming_pdu.event_id.clone()); - - // Now that the event has passed all auth it is added into the timeline. - // We use the `state_at_event` instead of `state_after` so we accurately - // represent the state for this event. - let pdu_id = self - .services - .timeline - .append_incoming_pdu( - &incoming_pdu, - val, - extremities.iter().map(|e| (**e).to_owned()).collect(), - state_ids_compressed, - soft_fail, - &state_lock, - ) - .await?; - - // Event has passed all auth/stateres checks - drop(state_lock); - debug_info!( - elapsed = ?timer.elapsed(), - "Accepted", - ); - - Ok(pdu_id) - } - - pub async fn resolve_state( - &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>, - ) -> Result<Arc<HashSet<CompressedStateEvent>>> { - debug!("Loading current room state ids"); - let current_sstatehash = self - .services - .state - .get_room_shortstatehash(room_id)? - .expect("every room has state"); - - let current_state_ids = self - .services - .state_accessor - .state_full_ids(current_sstatehash) - .await?; - - let fork_states = [current_state_ids, incoming_state]; - - let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); - for state in &fork_states { - auth_chain_sets.push( - self.services - .auth_chain - .event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect()) - .await? 
- .collect(), - ); - } - - debug!("Loading fork states"); - let fork_states: Vec<_> = fork_states - .into_iter() - .map(|map| { - map.into_iter() - .filter_map(|(k, id)| { - self.services - .short - .get_statekey_from_short(k) - .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) - .ok() - }) - .collect::<StateMap<_>>() - }) - .collect(); - - let lock = self.services.globals.stateres_mutex.lock(); - - debug!("Resolving state"); - let state_resolve = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = self.services.timeline.get_pdu(id); - if let Err(e) = &res { - error!("Failed to fetch event: {}", e); - } - res.ok().flatten() - }); - - let state = match state_resolve { - Ok(new_state) => new_state, - Err(e) => { - error!("State resolution failed: {}", e); - return Err(Error::bad_database( - "State resolution failed, either an event could not be found or deserialization", - )); - }, - }; - - drop(lock); - - debug!("State resolution done. Compressing state"); - let new_room_state = state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - self.services - .state_compressor - .compress_state_event(shortstatekey, &event_id) - }) - .collect::<Result<_>>()?; - - Ok(Arc::new(new_room_state)) - } - - // TODO: if we know the prev_events of the incoming event we can avoid the - // request and build the state from a known point and resolve if > 1 prev_event - #[tracing::instrument(skip_all, name = "state")] - pub async fn state_at_incoming_degree_one( - &self, incoming_pdu: &Arc<PduEvent>, - ) -> Result<Option<HashMap<u64, Arc<EventId>>>> { - let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = self - .services - .state_accessor - .pdu_shortstatehash(prev_event)?; - - let state = if let Some(shortstatehash) = prev_event_sstatehash { - Some( - self.services - .state_accessor - .state_full_ids(shortstatehash) - .await, - ) - } else { - None - }; - - if let Some(Ok(mut state)) = state { - debug!("Using cached state"); - let prev_pdu = self - .services - .timeline - .get_pdu(prev_event) - .ok() - .flatten() - .ok_or_else(|| Error::bad_database("Could not find prev event, but we know the state."))?; - - if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)?; - - state.insert(shortstatekey, Arc::from(prev_event)); - // Now it's the state after the pdu - } - - return Ok(Some(state)); - } - - Ok(None) - } - - #[tracing::instrument(skip_all, name = "state")] - pub async fn state_at_incoming_resolved( - &self, incoming_pdu: &Arc<PduEvent>, room_id: &RoomId, room_version_id: &RoomVersionId, - ) -> Result<Option<HashMap<u64, Arc<EventId>>>> { - debug!("Calculating state at event using state res"); - let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len()); - - let mut okay = true; - for prev_eventid in &incoming_pdu.prev_events { - let Ok(Some(prev_event)) = self.services.timeline.get_pdu(prev_eventid) else { - okay = false; - break; - }; - - let Ok(Some(sstatehash)) = self - .services - .state_accessor - .pdu_shortstatehash(prev_eventid) - else { - okay = false; - break; - }; - - extremity_sstatehashes.insert(sstatehash, prev_event); - } - - if !okay { - return Ok(None); - } - - let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); - let mut 
auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); - - for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: HashMap<_, _> = self - .services - .state_accessor - .state_full_ids(sstatehash) - .await?; - - if let Some(state_key) = &prev_event.state_key { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)?; - leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); - // Now it's the state after the pdu - } - - let mut state = StateMap::with_capacity(leaf_state.len()); - let mut starting_events = Vec::with_capacity(leaf_state.len()); - - for (k, id) in leaf_state { - if let Ok((ty, st_key)) = self.services.short.get_statekey_from_short(k) { - // FIXME: Undo .to_string().into() when StateMap - // is updated to use StateEventType - state.insert((ty.to_string().into(), st_key), id.clone()); - } else { - warn!("Failed to get_statekey_from_short."); - } - starting_events.push(id); - } - - auth_chain_sets.push( - self.services - .auth_chain - .event_ids_iter(room_id, starting_events) - .await? - .collect(), - ); - - fork_states.push(state); - } - - let lock = self.services.globals.stateres_mutex.lock(); - let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = self.services.timeline.get_pdu(id); - if let Err(e) = &res { - error!("Failed to fetch event: {}", e); - } - res.ok().flatten() - }); - drop(lock); - - Ok(match result { - Ok(new_state) => Some( - new_state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - Ok((shortstatekey, event_id)) - }) - .collect::<Result<_>>()?, - ), - Err(e) => { - warn!( - "State resolution on prev events failed, either an event could not be found or deserialization: {}", - e - ); - None - }, - }) - } - - /// Call /state_ids to find out what the state at this pdu is. 
We trust the - /// server's response to some extend (sic), but we still do a lot of checks - /// on the events - #[tracing::instrument(skip(self, pub_key_map, create_event, room_version_id))] - async fn fetch_state( - &self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, event_id: &EventId, - ) -> Result<Option<HashMap<u64, Arc<EventId>>>> { - debug!("Fetching state ids"); - match self - .services - .sending - .send_federation_request( - origin, - get_room_state_ids::v1::Request { - room_id: room_id.to_owned(), - event_id: (*event_id).to_owned(), - }, - ) - .await - { - Ok(res) => { - debug!("Fetching state events"); - let collect = res - .pdu_ids - .iter() - .map(|x| Arc::from(&**x)) - .collect::<Vec<_>>(); - - let state_vec = self - .fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id, pub_key_map) - .await; - - let mut state: HashMap<_, Arc<EventId>> = HashMap::with_capacity(state_vec.len()); - for (pdu, _) in state_vec { - let state_key = pdu - .state_key - .clone() - .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; - - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)?; - - match state.entry(shortstatekey) { - hash_map::Entry::Vacant(v) => { - v.insert(Arc::from(&*pdu.event_id)); - }, - hash_map::Entry::Occupied(_) => { - return Err(Error::bad_database( - "State event's type and state_key combination exists multiple times.", - )) - }, - } - } - - // The original create event must still be in the state - let create_shortstatekey = self - .services - .short - .get_shortstatekey(&StateEventType::RoomCreate, "")? - .expect("Room exists"); - - if state.get(&create_shortstatekey).map(AsRef::as_ref) != Some(&create_event.event_id) { - return Err(Error::bad_database("Incoming event refers to wrong create event.")); - } - - Ok(Some(state)) - }, - Err(e) => { - warn!("Fetching state for event failed: {}", e); - Err(e) - }, - } - } - - /// Find the event and auth it. Once the event is validated (steps 1 - 8) - /// it is appended to the outliers Tree. - /// - /// Returns pdu and if we fetched it over federation the raw json. - /// - /// a. Look in the main timeline (pduid_pdu tree) - /// b. Look at outlier pdu tree - /// c. Ask origin server over federation - /// d. TODO: Ask other servers over federation? - pub fn fetch_and_handle_outliers<'a>( - &'a self, origin: &'a ServerName, events: &'a [Arc<EventId>], create_event: &'a PduEvent, room_id: &'a RoomId, - room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - ) -> AsyncRecursiveCanonicalJsonVec<'a> { - Box::pin(async move { - let back_off = |id| async { - match self - .services - .globals - .bad_event_ratelimiter - .write() - .expect("locked") - .entry(id) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)), - } - }; - - let mut events_with_auth_events = Vec::with_capacity(events.len()); - for id in events { - // a. Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = self.services.timeline.get_pdu(id) { - trace!("Found {} in db", id); - events_with_auth_events.push((id, Some(local_pdu), vec![])); - continue; - } - - // c. 
Ask origin server over federation - // We also handle its auth chain here so we don't get a stack overflow in - // handle_outlier_pdu. - let mut todo_auth_events = vec![Arc::clone(id)]; - let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); - let mut events_all = HashSet::with_capacity(todo_auth_events.len()); - let mut i: u64 = 0; - while let Some(next_id) = todo_auth_events.pop() { - if let Some((time, tries)) = self - .services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(&*next_id) - { - // Exponential backoff - const MIN_DURATION: u64 = 5 * 60; - const MAX_DURATION: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { - info!("Backing off from {next_id}"); - continue; - } - } - - if events_all.contains(&next_id) { - continue; - } - - i = i.saturating_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - - if let Ok(Some(_)) = self.services.timeline.get_pdu(&next_id) { - trace!("Found {} in db", next_id); - continue; - } - - debug!("Fetching {} over federation.", next_id); - match self - .services - .sending - .send_federation_request( - origin, - get_event::v1::Request { - event_id: (*next_id).to_owned(), - }, - ) - .await - { - Ok(res) => { - debug!("Got {} over federation", next_id); - let Ok((calculated_event_id, value)) = - pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) - else { - back_off((*next_id).to_owned()).await; - continue; - }; - - if calculated_event_id != *next_id { - warn!( - "Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", - next_id, calculated_event_id, &res.pdu - ); - } - - if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) { - for auth_event in auth_events { - if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) { - let a: Arc<EventId> = auth_event; - todo_auth_events.push(a); - } else { - warn!("Auth event id is not valid"); - } - } - } else { - warn!("Auth event list invalid"); - } - - events_in_reverse_order.push((next_id.clone(), value)); - events_all.insert(next_id); - }, - Err(e) => { - debug_error!("Failed to fetch event {next_id}: {e}"); - back_off((*next_id).to_owned()).await; - }, - } - } - events_with_auth_events.push((id, None, events_in_reverse_order)); - } - - // We go through all the signatures we see on the PDUs and their unresolved - // dependencies and fetch the corresponding signing keys - self.services - .server_keys - .fetch_required_signing_keys( - events_with_auth_events - .iter() - .flat_map(|(_id, _local_pdu, events)| events) - .map(|(_event_id, event)| event), - pub_key_map, - ) - .await - .unwrap_or_else(|e| { - warn!("Could not fetch all signatures for PDUs from {}: {:?}", origin, e); - }); - - let mut pdus = Vec::with_capacity(events_with_auth_events.len()); - for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { - // a. Look in the main timeline (pduid_pdu tree) - // b. 
Look at outlier pdu tree - // (get_pdu_json checks both) - if let Some(local_pdu) = local_pdu { - trace!("Found {} in db", id); - pdus.push((local_pdu, None)); - } - for (next_id, value) in events_in_reverse_order.iter().rev() { - if let Some((time, tries)) = self - .services - .globals - .bad_event_ratelimiter - .read() - .expect("locked") - .get(&**next_id) - { - // Exponential backoff - const MIN_DURATION: u64 = 5 * 60; - const MAX_DURATION: u64 = 60 * 60 * 24; - if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { - debug!("Backing off from {next_id}"); - continue; - } - } - - match self - .handle_outlier_pdu(origin, create_event, next_id, room_id, value.clone(), true, pub_key_map) - .await - { - Ok((pdu, json)) => { - if next_id == id { - pdus.push((pdu, Some(json))); - } - }, - Err(e) => { - warn!("Authentication of event {} failed: {:?}", next_id, e); - back_off((**next_id).to_owned()).await; - }, - } - } - } - pdus - }) - } - - #[allow(clippy::type_complexity)] - #[tracing::instrument(skip_all)] - async fn fetch_prev( - &self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, initial_set: Vec<Arc<EventId>>, - ) -> Result<( - Vec<Arc<EventId>>, - HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>, - )> { - let mut graph: HashMap<Arc<EventId>, _> = HashMap::with_capacity(initial_set.len()); - let mut eventid_info = HashMap::new(); - let mut todo_outlier_stack: Vec<Arc<EventId>> = initial_set; - - let first_pdu_in_room = self - .services - .timeline - .first_pdu_in_room(room_id)? - .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; - - let mut amount = 0; - - while let Some(prev_event_id) = todo_outlier_stack.pop() { - if let Some((pdu, json_opt)) = self - .fetch_and_handle_outliers( - origin, - &[prev_event_id.clone()], - create_event, - room_id, - room_version_id, - pub_key_map, - ) - .await - .pop() - { - Self::check_room_id(room_id, &pdu)?; - - if amount > self.services.globals.max_fetch_prev_events() { - // Max limit reached - debug!( - "Max prev event limit reached! Limit: {}", - self.services.globals.max_fetch_prev_events() - ); - graph.insert(prev_event_id.clone(), HashSet::new()); - continue; - } - - if let Some(json) = json_opt.or_else(|| { - self.services - .outlier - .get_outlier_pdu_json(&prev_event_id) - .ok() - .flatten() - }) { - if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { - amount = amount.saturating_add(1); - for prev_prev in &pdu.prev_events { - if !graph.contains_key(prev_prev) { - todo_outlier_stack.push(prev_prev.clone()); - } - } - - graph.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect()); - } else { - // Time based check failed - graph.insert(prev_event_id.clone(), HashSet::new()); - } - - eventid_info.insert(prev_event_id.clone(), (pdu, json)); - } else { - // Get json failed, so this was not fetched over federation - graph.insert(prev_event_id.clone(), HashSet::new()); - } - } else { - // Fetch and handle failed - graph.insert(prev_event_id.clone(), HashSet::new()); - } - } - - let sorted = state_res::lexicographical_topological_sort(&graph, |event_id| { - // This return value is the key used for sorting events, - // events are then sorted by power level, time, - // and lexically by event_id. 
- Ok(( - int!(0), - MilliSecondsSinceUnixEpoch( - eventid_info - .get(event_id) - .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), - ), - )) - }) - .map_err(|e| { - error!("Error sorting prev events: {e}"); - Error::bad_database("Error sorting prev events") - })?; - - Ok((sorted, eventid_info)) - } - - /// Returns Ok if the acl allows the server - #[tracing::instrument(skip_all)] - pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { - let acl_event = if let Some(acl) = - self.services - .state_accessor - .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? - { - trace!("ACL event found: {acl:?}"); - acl - } else { - trace!("No ACL event found"); - return Ok(()); - }; - - let acl_event_content: RoomServerAclEventContent = match serde_json::from_str(acl_event.content.get()) { - Ok(content) => { - trace!("Found ACL event contents: {content:?}"); - content - }, - Err(e) => { - warn!("Invalid ACL event: {e}"); - return Ok(()); - }, - }; - - if acl_event_content.allow.is_empty() { - warn!("Ignoring broken ACL event (allow key is empty)"); - // Ignore broken acl events - return Ok(()); - } - - if acl_event_content.is_allowed(server_name) { - trace!("server {server_name} is allowed by ACL"); - Ok(()) - } else { - debug!("Server {} was denied by room ACL in {}", server_name, room_id); - Err(Error::BadRequest(ErrorKind::forbidden(), "Server was denied by room ACL")) - } +fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result { + if pdu.room_id != room_id { + return Err!(Request(InvalidParam(error!( + pdu_event_id = ?pdu.event_id, + pdu_room_id = ?pdu.room_id, + ?room_id, + "Found event from room in room", + )))); } - fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result<()> { - if pdu.room_id != room_id { - warn!("Found event from room {} in room {}", pdu.room_id, room_id); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has wrong room id")); - } - Ok(()) - } + Ok(()) +} - fn get_room_version_id(create_event: &PduEvent) -> Result<RoomVersionId> { - let create_event_content: RoomCreateEventContent = serde_json::from_str(create_event.content.get()) - .map_err(|e| err!(Database("Invalid create event: {e}")))?; +fn get_room_version_id(create_event: &PduEvent) -> Result<RoomVersionId> { + let content: RoomCreateEventContent = create_event.get_content()?; + let room_version = content.room_version; - Ok(create_event_content.room_version) - } + Ok(room_version) +} - #[inline] - fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion { - RoomVersion::new(room_version_id).expect("room version is supported") - } +#[inline] +fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion { + RoomVersion::new(room_version_id).expect("room version is supported") } diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index a7ffe193055b2af0916aca21388db4b0920cadd2..42f44deeca407cd9e288ff06f2ec7ad51758cf36 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -1,28 +1,27 @@ -use conduit::{debug_warn, err, pdu::gen_event_id_canonical_json, Err, Result}; -use ruma::{CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId}; +use conduit::{err, implement, pdu::gen_event_id_canonical_json, result::FlatOk, Result}; +use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId, RoomId}; use serde_json::value::RawValue as RawJsonValue; -impl super::Service { - pub fn 
parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - debug_warn!("Error parsing incoming event {pdu:#?}"); - err!(BadServerResponse("Error parsing incoming event {e:?}")) - })?; +#[implement(super::Service)] +pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { + let value = serde_json::from_str::<CanonicalJsonObject>(pdu.get()) + .map_err(|e| err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}"))))?; - let room_id: OwnedRoomId = value - .get("room_id") - .and_then(|id| RoomId::parse(id.as_str()?).ok()) - .ok_or(err!(Request(InvalidParam("Invalid room id in pdu"))))?; + let room_id: OwnedRoomId = value + .get("room_id") + .and_then(CanonicalJsonValue::as_str) + .map(RoomId::parse) + .flat_ok_or(err!(Request(InvalidParam("Invalid room_id in pdu"))))?; - let Ok(room_version_id) = self.services.state.get_room_version(&room_id) else { - return Err!("Server is not in room {room_id}"); - }; + let room_version_id = self + .services + .state + .get_room_version(&room_id) + .await + .map_err(|_| err!("Server is not in room {room_id}"))?; - let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { - // Event could not be converted to canonical json - return Err!(Request(InvalidParam("Could not convert event to canonical json."))); - }; + let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id) + .map_err(|e| err!(Request(InvalidParam("Could not convert event to canonical json: {e}"))))?; - Ok((event_id, value, room_id)) - } + Ok((event_id, value, room_id)) } diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs new file mode 100644 index 0000000000000000000000000000000000000000..0c9525dd7f681e5264569bd4da5f8eeab3a0951a --- /dev/null +++ b/src/service/rooms/event_handler/resolve_state.rs @@ -0,0 +1,101 @@ +use std::{ + borrow::Borrow, + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use conduit::{debug, err, implement, utils::IterStream, Result}; +use futures::{FutureExt, StreamExt, TryFutureExt}; +use ruma::{ + state_res::{self, StateMap}, + EventId, RoomId, RoomVersionId, +}; + +use crate::rooms::state_compressor::CompressedStateEvent; + +#[implement(super::Service)] +#[tracing::instrument(skip_all, name = "resolve")] +pub async fn resolve_state( + &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>, +) -> Result<Arc<HashSet<CompressedStateEvent>>> { + debug!("Loading current room state ids"); + let current_sstatehash = self + .services + .state + .get_room_shortstatehash(room_id) + .await + .map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?; + + let current_state_ids = self + .services + .state_accessor + .state_full_ids(current_sstatehash) + .await?; + + let fork_states = [current_state_ids, incoming_state]; + let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); + for state in &fork_states { + let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect(); + + let auth_chain: HashSet<Arc<EventId>> = self + .services + .auth_chain + .get_event_ids(room_id, &starting_events) + .await? 
+		.into_iter()
+		.collect();
+
+		auth_chain_sets.push(auth_chain);
+	}
+
+	debug!("Loading fork states");
+	let fork_states: Vec<StateMap<Arc<EventId>>> = fork_states
+		.into_iter()
+		.stream()
+		.then(|fork_state| {
+			fork_state
+				.into_iter()
+				.stream()
+				.filter_map(|(k, id)| {
+					self.services
+						.short
+						.get_statekey_from_short(k)
+						.map_ok_or_else(|_| None, move |(ty, st_key)| Some(((ty, st_key), id)))
+				})
+				.collect()
+		})
+		.collect()
+		.boxed()
+		.await;
+
+	debug!("Resolving state");
+	let lock = self.services.globals.stateres_mutex.lock();
+
+	let event_fetch = |event_id| self.event_fetch(event_id);
+	let event_exists = |event_id| self.event_exists(event_id);
+	let state = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists)
+		.await
+		.map_err(|e| err!(Database(error!("State resolution failed: {e:?}"))))?;
+
+	drop(lock);
+
+	debug!("State resolution done. Compressing state");
+	let mut new_room_state = HashSet::new();
+	for ((event_type, state_key), event_id) in state {
+		let shortstatekey = self
+			.services
+			.short
+			.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)
+			.await;
+
+		let compressed = self
+			.services
+			.state_compressor
+			.compress_state_event(shortstatekey, &event_id)
+			.await;
+
+		new_room_state.insert(compressed);
+	}
+
+	Ok(Arc::new(new_room_state))
+}
diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs
new file mode 100644
index 0000000000000000000000000000000000000000..a200ab5689a046723eece8084723c3b053dd05f5
--- /dev/null
+++ b/src/service/rooms/event_handler/state_at_incoming.rs
@@ -0,0 +1,178 @@
+use std::{
+	borrow::Borrow,
+	collections::{HashMap, HashSet},
+	sync::Arc,
+};
+
+use conduit::{debug, err, implement, result::LogErr, utils::IterStream, PduEvent, Result};
+use futures::{FutureExt, StreamExt};
+use ruma::{
+	state_res::{self, StateMap},
+	EventId, RoomId, RoomVersionId,
+};
+
+// TODO: if we know the prev_events of the incoming event we can avoid the
+// request and build the state from a known point and resolve if > 1 prev_event
+#[implement(super::Service)]
+#[tracing::instrument(skip_all, name = "state")]
+pub(super) async fn state_at_incoming_degree_one(
+	&self, incoming_pdu: &Arc<PduEvent>,
+) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
+	let prev_event = &*incoming_pdu.prev_events[0];
+	let Ok(prev_event_sstatehash) = self
+		.services
+		.state_accessor
+		.pdu_shortstatehash(prev_event)
+		.await
+	else {
+		return Ok(None);
+	};
+
+	let Ok(mut state) = self
+		.services
+		.state_accessor
+		.state_full_ids(prev_event_sstatehash)
+		.await
+		.log_err()
+	else {
+		return Ok(None);
+	};
+
+	debug!("Using cached state");
+	let prev_pdu = self
+		.services
+		.timeline
+		.get_pdu(prev_event)
+		.await
+		.map_err(|e| err!(Database("Could not find prev event, but we know the state: {e:?}")))?;
+
+	if let Some(state_key) = &prev_pdu.state_key {
+		let shortstatekey = self
+			.services
+			.short
+			.get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)
+			.await;
+
+		state.insert(shortstatekey, Arc::from(prev_event));
+		// Now it's the state after the pdu
+	}
+
+	debug_assert!(!state.is_empty(), "should be returning None for empty HashMap result");
+
+	Ok(Some(state))
+}
+
+#[implement(super::Service)]
+#[tracing::instrument(skip_all, name = "state")]
+pub(super) async fn state_at_incoming_resolved(
+	&self, incoming_pdu: &Arc<PduEvent>, room_id: &RoomId, room_version_id: &RoomVersionId,
+) ->
Result<Option<HashMap<u64, Arc<EventId>>>> { + debug!("Calculating state at event using state res"); + let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len()); + + let mut okay = true; + for prev_eventid in &incoming_pdu.prev_events { + let Ok(prev_event) = self.services.timeline.get_pdu(prev_eventid).await else { + okay = false; + break; + }; + + let Ok(sstatehash) = self + .services + .state_accessor + .pdu_shortstatehash(prev_eventid) + .await + else { + okay = false; + break; + }; + + extremity_sstatehashes.insert(sstatehash, prev_event); + } + + if !okay { + return Ok(None); + } + + let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); + let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); + for (sstatehash, prev_event) in extremity_sstatehashes { + let Ok(mut leaf_state) = self + .services + .state_accessor + .state_full_ids(sstatehash) + .await + else { + continue; + }; + + if let Some(state_key) = &prev_event.state_key { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key) + .await; + + let event_id = &prev_event.event_id; + leaf_state.insert(shortstatekey, event_id.clone()); + // Now it's the state after the pdu + } + + let mut state = StateMap::with_capacity(leaf_state.len()); + let mut starting_events = Vec::with_capacity(leaf_state.len()); + for (k, id) in &leaf_state { + if let Ok((ty, st_key)) = self + .services + .short + .get_statekey_from_short(*k) + .await + .log_err() + { + // FIXME: Undo .to_string().into() when StateMap + // is updated to use StateEventType + state.insert((ty.to_string().into(), st_key), id.clone()); + } + + starting_events.push(id.borrow()); + } + + let auth_chain: HashSet<Arc<EventId>> = self + .services + .auth_chain + .get_event_ids(room_id, &starting_events) + .await? 
+ .into_iter() + .collect(); + + auth_chain_sets.push(auth_chain); + fork_states.push(state); + } + + let lock = self.services.globals.stateres_mutex.lock(); + + let event_fetch = |event_id| self.event_fetch(event_id); + let event_exists = |event_id| self.event_exists(event_id); + let result = state_res::resolve(room_version_id, &fork_states, &auth_chain_sets, &event_fetch, &event_exists) + .await + .map_err(|e| err!(Database(warn!(?e, "State resolution on prev events failed.")))); + + drop(lock); + + let Ok(new_state) = result else { + return Ok(None); + }; + + new_state + .iter() + .stream() + .then(|((event_type, state_key), event_id)| { + self.services + .short + .get_or_create_shortstatekey(event_type, state_key) + .map(move |shortstatekey| (shortstatekey, event_id.clone())) + }) + .collect() + .map(Some) + .map(Ok) + .await +} diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs new file mode 100644 index 0000000000000000000000000000000000000000..2a1e46625caebf93e4fdca412e71f7bdc19c7ceb --- /dev/null +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -0,0 +1,298 @@ +use std::{ + collections::{BTreeMap, HashSet}, + sync::Arc, + time::Instant, +}; + +use conduit::{debug, debug_info, err, implement, trace, warn, Err, Error, PduEvent, Result}; +use futures::{future::ready, StreamExt}; +use ruma::{ + api::client::error::ErrorKind, + events::{room::redaction::RoomRedactionEventContent, StateEventType, TimelineEventType}, + state_res::{self, EventTypeExt}, + CanonicalJsonValue, RoomId, RoomVersionId, ServerName, +}; + +use super::{get_room_version_id, to_room_version}; +use crate::rooms::{state_compressor::HashSetCompressStateEvent, timeline::RawPduId}; + +#[implement(super::Service)] +pub(super) async fn upgrade_outlier_to_timeline_pdu( + &self, incoming_pdu: Arc<PduEvent>, val: BTreeMap<String, CanonicalJsonValue>, create_event: &PduEvent, + origin: &ServerName, room_id: &RoomId, +) -> Result<Option<RawPduId>> { + // Skip the PDU if we already have it as a timeline event + if let Ok(pduid) = self + .services + .timeline + .get_pdu_id(&incoming_pdu.event_id) + .await + { + return Ok(Some(pduid)); + } + + if self + .services + .pdu_metadata + .is_event_soft_failed(&incoming_pdu.event_id) + .await + { + return Err!(Request(InvalidParam("Event has been soft failed"))); + } + + debug!("Upgrading to timeline pdu"); + let timer = Instant::now(); + let room_version_id = get_room_version_id(create_event)?; + + // 10. Fetch missing state and auth chain events by calling /state_ids at + // backwards extremities doing all the checks in this list starting at 1. + // These are not timeline events. + + debug!("Resolving state at event"); + let mut state_at_incoming_event = if incoming_pdu.prev_events.len() == 1 { + self.state_at_incoming_degree_one(&incoming_pdu).await? + } else { + self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_id) + .await? + }; + + if state_at_incoming_event.is_none() { + state_at_incoming_event = self + .fetch_state(origin, create_event, room_id, &room_version_id, &incoming_pdu.event_id) + .await?; + } + + let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above"); + let room_version = to_room_version(&room_version_id); + + debug!("Performing auth check"); + // 11. 
Check that the auth of the event passes based on the state of the event
+	let state_fetch_state = &state_at_incoming_event;
+	let state_fetch = |k: &'static StateEventType, s: String| async move {
+		let shortstatekey = self.services.short.get_shortstatekey(k, &s).await.ok()?;
+
+		let event_id = state_fetch_state.get(&shortstatekey)?;
+		self.services.timeline.get_pdu(event_id).await.ok()
+	};
+
+	let auth_check = state_res::event_auth::auth_check(
+		&room_version,
+		&incoming_pdu,
+		None, // TODO: third party invite
+		|k, s| state_fetch(k, s.to_owned()),
+	)
+	.await
+	.map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?;
+
+	if !auth_check {
+		return Err!(Request(Forbidden("Event has failed auth check with state at the event.")));
+	}
+
+	debug!("Gathering auth events");
+	let auth_events = self
+		.services
+		.state
+		.get_auth_events(
+			room_id,
+			&incoming_pdu.kind,
+			&incoming_pdu.sender,
+			incoming_pdu.state_key.as_deref(),
+			&incoming_pdu.content,
+		)
+		.await?;
+
+	let state_fetch = |k: &'static StateEventType, s: &str| {
+		let key = k.with_state_key(s);
+		ready(auth_events.get(&key).cloned())
+	};
+
+	let auth_check = state_res::event_auth::auth_check(
+		&room_version,
+		&incoming_pdu,
+		None, // third-party invite
+		state_fetch,
+	)
+	.await
+	.map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?;
+
+	// Soft fail check before doing state res
+	debug!("Performing soft-fail check");
+	let soft_fail = {
+		use RoomVersionId::*;
+
+		!auth_check
+			|| incoming_pdu.kind == TimelineEventType::RoomRedaction
+				&& match room_version_id {
+					V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
+						if let Some(redact_id) = &incoming_pdu.redacts {
+							!self
+								.services
+								.state_accessor
+								.user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true)
+								.await?
+						} else {
+							false
+						}
+					},
+					_ => {
+						let content: RoomRedactionEventContent = incoming_pdu.get_content()?;
+						if let Some(redact_id) = &content.redacts {
+							!self
+								.services
+								.state_accessor
+								.user_can_redact(redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, true)
+								.await?
+						} else {
+							false
+						}
+					},
+				}
+	};
+
+	// 13. Use state resolution to find new room state
+
+	// We start looking at current room state now, so let's lock the room
+	trace!("Locking the room");
+	let state_lock = self.services.state.mutex.lock(room_id).await;
+
+	// Now we calculate the set of extremities this room has after the incoming
+	// event has been applied. We start with the previous extremities (aka leaves)
+	trace!("Calculating extremities");
+	let mut extremities: HashSet<_> = self
+		.services
+		.state
+		.get_forward_extremities(room_id)
+		.map(ToOwned::to_owned)
+		.collect()
+		.await;
+
+	// Remove any forward extremities that are referenced by this incoming event's
+	// prev_events
+	trace!(
+		"Calculated {} extremities; checking against {} prev_events",
+		extremities.len(),
+		incoming_pdu.prev_events.len()
+	);
+	for prev_event in &incoming_pdu.prev_events {
+		extremities.remove(&(**prev_event));
+	}
+
+	// Only keep those extremities that were not referenced yet
+	let mut retained = HashSet::new();
+	for id in &extremities {
+		if !self
+			.services
+			.pdu_metadata
+			.is_event_referenced(room_id, id)
+			.await
+		{
+			retained.insert(id.clone());
+		}
+	}
+
+	extremities.retain(|id| retained.contains(id));
+	debug!("Retained {} extremities.
Compressing state", extremities.len()); + + let mut state_ids_compressed = HashSet::new(); + for (shortstatekey, id) in &state_at_incoming_event { + state_ids_compressed.insert( + self.services + .state_compressor + .compress_state_event(*shortstatekey, id) + .await, + ); + } + + let state_ids_compressed = Arc::new(state_ids_compressed); + + if incoming_pdu.state_key.is_some() { + debug!("Event is a state-event. Deriving new room state"); + + // We also add state after incoming event to the fork states + let mut state_after = state_at_incoming_event.clone(); + if let Some(state_key) = &incoming_pdu.state_key { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key) + .await; + + let event_id = &incoming_pdu.event_id; + state_after.insert(shortstatekey, event_id.clone()); + } + + let new_room_state = self + .resolve_state(room_id, &room_version_id, state_after) + .await?; + + // Set the new room state to the resolved state + debug!("Forcing new room state"); + let HashSetCompressStateEvent { + shortstatehash, + added, + removed, + } = self + .services + .state_compressor + .save_state(room_id, new_room_state) + .await?; + + self.services + .state + .force_state(room_id, shortstatehash, added, removed, &state_lock) + .await?; + } + + // 14. Check if the event passes auth based on the "current state" of the room, + // if not soft fail it + if soft_fail { + debug!("Soft failing event"); + self.services + .timeline + .append_incoming_pdu( + &incoming_pdu, + val, + extremities.iter().map(|e| (**e).to_owned()).collect(), + state_ids_compressed, + soft_fail, + &state_lock, + ) + .await?; + + // Soft fail, we keep the event as an outlier but don't add it to the timeline + warn!("Event was soft failed: {incoming_pdu:?}"); + self.services + .pdu_metadata + .mark_event_soft_failed(&incoming_pdu.event_id); + + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); + } + + trace!("Appending pdu to timeline"); + extremities.insert(incoming_pdu.event_id.clone().into()); + + // Now that the event has passed all auth it is added into the timeline. + // We use the `state_at_event` instead of `state_after` so we accurately + // represent the state for this event. 
+ let pdu_id = self + .services + .timeline + .append_incoming_pdu( + &incoming_pdu, + val, + extremities.into_iter().collect(), + state_ids_compressed, + soft_fail, + &state_lock, + ) + .await?; + + // Event has passed all auth/stateres checks + drop(state_lock); + debug_info!( + elapsed = ?timer.elapsed(), + "Accepted", + ); + + Ok(pdu_id) +} diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs deleted file mode 100644 index 073d45f565bb00c91c7b8b4529aaf3db75b90304..0000000000000000000000000000000000000000 --- a/src/service/rooms/lazy_loading/data.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::sync::Arc; - -use conduit::Result; -use database::{Database, Map}; -use ruma::{DeviceId, RoomId, UserId}; - -pub(super) struct Data { - lazyloadedids: Arc<Map>, -} - -impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { - Self { - lazyloadedids: db["lazyloadedids"].clone(), - } - } - - pub(super) fn lazy_load_was_sent_before( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, - ) -> Result<bool> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(ll_user.as_bytes()); - Ok(self.lazyloadedids.get(&key)?.is_some()) - } - - pub(super) fn lazy_load_confirm_delivery( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, - confirmed_user_ids: &mut dyn Iterator<Item = &UserId>, - ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - for ll_id in confirmed_user_ids { - let mut key = prefix.clone(); - key.extend_from_slice(ll_id.as_bytes()); - self.lazyloadedids.insert(&key, &[])?; - } - - Ok(()) - } - - pub(super) fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xFF); - - for (key, _) in self.lazyloadedids.scan_prefix(prefix) { - self.lazyloadedids.remove(&key)?; - } - - Ok(()) - } -} diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 0a9d4cf29d4959aab21a79f56f12735a2d32b6d9..7a4da2a64cb7a86c83851644806cda07c3a26409 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,21 +1,26 @@ -mod data; - use std::{ collections::{HashMap, HashSet}, fmt::Write, sync::{Arc, Mutex}, }; -use conduit::{PduCount, Result}; +use conduit::{ + implement, + utils::{stream::TryIgnore, ReadyExt}, + PduCount, Result, +}; +use database::{Interfix, Map}; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; -use self::data::Data; - pub struct Service { - pub lazy_load_waiting: Mutex<LazyLoadWaiting>, + lazy_load_waiting: Mutex<LazyLoadWaiting>, db: Data, } +struct Data { + lazyloadedids: Arc<Map>, +} + type LazyLoadWaiting = HashMap<LazyLoadWaitingKey, LazyLoadWaitingVal>; type LazyLoadWaitingKey = (OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount); type LazyLoadWaitingVal = HashSet<OwnedUserId>; @@ -23,8 +28,10 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { 
- lazy_load_waiting: Mutex::new(HashMap::new()), - db: Data::new(args.db), + lazy_load_waiting: LazyLoadWaiting::new().into(), + db: Data { + lazyloadedids: args.db["lazyloadedids"].clone(), + }, })) } @@ -40,47 +47,52 @@ fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - #[tracing::instrument(skip(self), level = "debug")] - pub fn lazy_load_was_sent_before( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, - ) -> Result<bool> { - self.db - .lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) - } +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +#[inline] +pub async fn lazy_load_was_sent_before( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, +) -> bool { + let key = (user_id, device_id, room_id, ll_user); + self.db.lazyloadedids.qry(&key).await.is_ok() +} - #[tracing::instrument(skip(self), level = "debug")] - pub async fn lazy_load_mark_sent( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>, - count: PduCount, - ) { - self.lazy_load_waiting - .lock() - .expect("locked") - .insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load); - } +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn lazy_load_mark_sent( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>, count: PduCount, +) { + let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count); - #[tracing::instrument(skip(self), level = "debug")] - pub async fn lazy_load_confirm_delivery( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount, - ) -> Result<()> { - if let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&( - user_id.to_owned(), - device_id.to_owned(), - room_id.to_owned(), - since, - )) { - self.db - .lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?; - } else { - // Ignore - } + self.lazy_load_waiting + .lock() + .expect("locked") + .insert(key, lazy_load); +} - Ok(()) - } +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn lazy_load_confirm_delivery(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount) { + let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), since); - #[tracing::instrument(skip(self), level = "debug")] - pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { - self.db.lazy_load_reset(user_id, device_id, room_id) + let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&key) else { + return; + }; + + for ll_id in &user_ids { + let key = (user_id, device_id, room_id, ll_id); + self.db.lazyloadedids.put_raw(key, []); } } + +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) { + let prefix = (user_id, device_id, room_id, Interfix); + self.db + .lazyloadedids + .keys_prefix_raw(&prefix) + .ignore_err() + .ready_for_each(|key| self.db.lazyloadedids.remove(key)) + .await; +} diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs deleted file mode 100644 index efe681b1bc2f2edd8d94061b2da589fd41e2d759..0000000000000000000000000000000000000000 
--- a/src/service/rooms/metadata/data.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::sync::Arc; - -use conduit::{error, utils, Error, Result}; -use database::Map; -use ruma::{OwnedRoomId, RoomId}; - -use crate::{rooms, Dep}; - -pub(super) struct Data { - disabledroomids: Arc<Map>, - bannedroomids: Arc<Map>, - roomid_shortroomid: Arc<Map>, - pduid_pdu: Arc<Map>, - services: Services, -} - -struct Services { - short: Dep<rooms::short::Service>, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - disabledroomids: db["disabledroomids"].clone(), - bannedroomids: db["bannedroomids"].clone(), - roomid_shortroomid: db["roomid_shortroomid"].clone(), - pduid_pdu: db["pduid_pdu"].clone(), - services: Services { - short: args.depend::<rooms::short::Service>("rooms::short"), - }, - } - } - - pub(super) fn exists(&self, room_id: &RoomId) -> Result<bool> { - let prefix = match self.services.short.get_shortroomid(room_id)? { - Some(b) => b.to_be_bytes().to_vec(), - None => return Ok(false), - }; - - // Look for PDUs in that room. - Ok(self - .pduid_pdu - .iter_from(&prefix, false) - .next() - .filter(|(k, _)| k.starts_with(&prefix)) - .is_some()) - } - - pub(super) fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { - Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) - })) - } - - #[inline] - pub(super) fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { - Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) - } - - #[inline] - pub(super) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { - if disabled { - self.disabledroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.disabledroomids.remove(room_id.as_bytes())?; - } - - Ok(()) - } - - #[inline] - pub(super) fn is_banned(&self, room_id: &RoomId) -> Result<bool> { - Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) - } - - #[inline] - pub(super) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { - if banned { - self.bannedroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.bannedroomids.remove(room_id.as_bytes())?; - } - - Ok(()) - } - - pub(super) fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { - Box::new(self.bannedroomids.iter().map( - |(room_id_bytes, _ /* non-banned rooms should not be in this table */)| { - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|e| { - error!("Invalid room_id bytes in bannedroomids: {e}"); - Error::bad_database("Invalid room_id in bannedroomids.") - })? 
- .try_into() - .map_err(|e| { - error!("Invalid room_id in bannedroomids: {e}"); - Error::bad_database("Invalid room_id in bannedroomids") - })?; - - Ok(room_id) - }, - )) - } -} diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 7415c53b76fd349e4acb1c145ae9790a1c40e4b0..4ee390a5c22dfc57df0c33b3e31fcd24d5d80e46 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -1,51 +1,92 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use ruma::{OwnedRoomId, RoomId}; +use conduit::{implement, utils::stream::TryIgnore, Result}; +use database::Map; +use futures::{Stream, StreamExt}; +use ruma::RoomId; -use self::data::Data; +use crate::{rooms, Dep}; pub struct Service { db: Data, + services: Services, +} + +struct Data { + disabledroomids: Arc<Map>, + bannedroomids: Arc<Map>, + roomid_shortroomid: Arc<Map>, + pduid_pdu: Arc<Map>, +} + +struct Services { + short: Dep<rooms::short::Service>, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + disabledroomids: args.db["disabledroomids"].clone(), + bannedroomids: args.db["bannedroomids"].clone(), + roomid_shortroomid: args.db["roomid_shortroomid"].clone(), + pduid_pdu: args.db["pduid_pdu"].clone(), + }, + services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Checks if a room exists. - #[inline] - pub fn exists(&self, room_id: &RoomId) -> Result<bool> { self.db.exists(room_id) } +#[implement(Service)] +pub async fn exists(&self, room_id: &RoomId) -> bool { + let Ok(prefix) = self.services.short.get_shortroomid(room_id).await else { + return false; + }; + + // Look for PDUs in that room. 
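+	// A room is considered to exist when at least one PDU key is stored under
+	// its shortroomid prefix; probing only the first key of the prefix scan
+	// avoids deserializing any PDU.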
+ self.db + .pduid_pdu + .keys_prefix_raw(&prefix) + .ignore_err() + .next() + .await + .is_some() +} - #[must_use] - pub fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { self.db.iter_ids() } +#[implement(Service)] +pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.roomid_shortroomid.keys().ignore_err() } - #[inline] - pub fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { self.db.is_disabled(room_id) } +#[implement(Service)] +#[inline] +pub fn disable_room(&self, room_id: &RoomId, disabled: bool) { + if disabled { + self.db.disabledroomids.insert(room_id, []); + } else { + self.db.disabledroomids.remove(room_id); + } +} - #[inline] - pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { - self.db.disable_room(room_id, disabled) +#[implement(Service)] +#[inline] +pub fn ban_room(&self, room_id: &RoomId, banned: bool) { + if banned { + self.db.bannedroomids.insert(room_id, []); + } else { + self.db.bannedroomids.remove(room_id); } +} - #[inline] - pub fn is_banned(&self, room_id: &RoomId) -> Result<bool> { self.db.is_banned(room_id) } +#[implement(Service)] +pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.bannedroomids.keys().ignore_err() } - #[inline] - pub fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) } +#[implement(Service)] +#[inline] +pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.get(room_id).await.is_ok() } - #[inline] - #[must_use] - pub fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { - self.db.list_banned_rooms() - } -} +#[implement(Service)] +#[inline] +pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.get(room_id).await.is_ok() } diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs deleted file mode 100644 index aa804721b659aad5582fe5069b5560e4ebdc5287..0000000000000000000000000000000000000000 --- a/src/service/rooms/outlier/data.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::sync::Arc; - -use conduit::{Error, Result}; -use database::{Database, Map}; -use ruma::{CanonicalJsonObject, EventId}; - -use crate::PduEvent; - -pub(super) struct Data { - eventid_outlierpdu: Arc<Map>, -} - -impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { - Self { - eventid_outlierpdu: db["eventid_outlierpdu"].clone(), - } - } - - pub(super) fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } - - pub(super) fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? 
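Old and new `exists` agree on the approach: resolve the room to its shortroomid, then probe whether any `pduid_pdu` key begins with that 8-byte big-endian prefix (the deleted code built the prefix by hand; the rewrite goes through `keys_prefix_raw`). A sketch of the probe over an ordered map:

```rust
use std::collections::BTreeMap;

/// Does any key in `pduid_pdu` start with the 8-byte big-endian
/// shortroomid? A std BTreeMap range scan stands in for the database
/// iterator here.
fn room_has_pdus(pduid_pdu: &BTreeMap<Vec<u8>, Vec<u8>>, shortroomid: u64) -> bool {
    let prefix = shortroomid.to_be_bytes();
    pduid_pdu
        .range(prefix.to_vec()..) // seek to the first key >= prefix
        .next()
        .is_some_and(|(key, _)| key.starts_with(&prefix))
}

fn main() {
    let mut pduid_pdu = BTreeMap::new();
    // A PDU id is shortroomid (BE u64) followed by a count (BE u64).
    let mut pdu_id = 7u64.to_be_bytes().to_vec();
    pdu_id.extend_from_slice(&1u64.to_be_bytes());
    pduid_pdu.insert(pdu_id, b"{}".to_vec());

    assert!(room_has_pdus(&pduid_pdu, 7));
    assert!(!room_has_pdus(&pduid_pdu, 8));
}
```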
- .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } - - pub(super) fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.eventid_outlierpdu.insert( - event_id.as_bytes(), - &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), - ) - } -} diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 22bd2092a2a2afe91a89fadc7f138f0808651ae9..03e7783892307fbad29bfa19fd958bcb71e4c2e3 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -1,9 +1,7 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use data::Data; +use conduit::{implement, Result}; +use database::{Deserialized, Json, Map}; use ruma::{CanonicalJsonObject, EventId}; use crate::PduEvent; @@ -12,31 +10,45 @@ pub struct Service { db: Data, } +struct Data { + eventid_outlierpdu: Arc<Map>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data { + eventid_outlierpdu: args.db["eventid_outlierpdu"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Returns the pdu from the outlier tree. - pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { - self.db.get_outlier_pdu_json(event_id) - } +/// Returns the pdu from the outlier tree. +#[implement(Service)] +pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> { + self.db + .eventid_outlierpdu + .get(event_id) + .await + .deserialized() +} - /// Returns the pdu from the outlier tree. - /// - /// TODO: use this? - #[allow(dead_code)] - pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result<Option<PduEvent>> { self.db.get_outlier_pdu(event_id) } +/// Returns the pdu from the outlier tree. +#[implement(Service)] +pub async fn get_pdu_outlier(&self, event_id: &EventId) -> Result<PduEvent> { + self.db + .eventid_outlierpdu + .get(event_id) + .await + .deserialized() +} - /// Append the PDU as an outlier. - #[tracing::instrument(skip(self, pdu), level = "debug")] - pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.db.add_pdu_outlier(event_id, pdu) - } +/// Append the PDU as an outlier. 
+#[implement(Service)] +#[tracing::instrument(skip(self, pdu), level = "debug")] +pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) { + self.db.eventid_outlierpdu.raw_put(event_id, Json(pdu)); } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index d1649da813464176672b528808b74ee3a70be3c0..b06e988e819c549bc72f88b94b0baccffce97221 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,10 +1,23 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, PduCount, PduEvent, Result}; +use arrayvec::ArrayVec; +use conduit::{ + result::LogErr, + utils::{stream::TryIgnore, u64_from_u8, ReadyExt}, + PduCount, PduEvent, +}; use database::Map; -use ruma::{EventId, RoomId, UserId}; - -use crate::{rooms, Dep}; +use futures::{Stream, StreamExt}; +use ruma::{api::Direction, EventId, RoomId, UserId}; + +use crate::{ + rooms, + rooms::{ + short::{ShortEventId, ShortRoomId}, + timeline::{PduId, RawPduId}, + }, + Dep, +}; pub(super) struct Data { tofrom_relation: Arc<Map>, @@ -17,8 +30,7 @@ struct Services { timeline: Dep<rooms::timeline::Service>, } -type PdusIterItem = Result<(PduCount, PduEvent)>; -type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>; +pub(super) type PdusIterItem = (PduCount, PduEvent); impl Data { pub(super) fn new(args: &crate::Args<'_>) -> Self { @@ -33,75 +45,60 @@ pub(super) fn new(args: &crate::Args<'_>) -> Self { } } - pub(super) fn add_relation(&self, from: u64, to: u64) -> Result<()> { - let mut key = to.to_be_bytes().to_vec(); - key.extend_from_slice(&from.to_be_bytes()); - self.tofrom_relation.insert(&key, &[])?; - Ok(()) + pub(super) fn add_relation(&self, from: u64, to: u64) { + const BUFSIZE: usize = size_of::<u64>() * 2; + + let key: &[u64] = &[to, from]; + self.tofrom_relation.aput_raw::<BUFSIZE, _, _>(key, []); } - pub(super) fn relations_until<'a>( - &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, - ) -> Result<PdusIterator<'a>> { - let prefix = target.to_be_bytes().to_vec(); - let mut current = prefix.clone(); - - let count_raw = match until { - PduCount::Normal(x) => x.saturating_sub(1), - PduCount::Backfilled(x) => { - current.extend_from_slice(&0_u64.to_be_bytes()); - u64::MAX.saturating_sub(x).saturating_sub(1) - }, - }; - current.extend_from_slice(&count_raw.to_be_bytes()); - - Ok(Box::new( - self.tofrom_relation - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(tofrom, _data)| { - let from = utils::u64_from_bytes(&tofrom[(size_of::<u64>())..]) - .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; - - let mut pduid = shortroomid.to_be_bytes().to_vec(); - pduid.extend_from_slice(&from.to_be_bytes()); - - let mut pdu = self - .services - .timeline - .get_pdu_from_id(&pduid)? 
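The reworked `add_relation` packs the key as two big-endian u64s, target first, so every relation pointing at one event is contiguous; `get_relations` later slices the source id back out of bytes 8..16. A round-trip sketch using a plain `[u8; 16]` in place of the diff's `ArrayVec<u8, 16>`:

```rust
/// Compose the `tofrom_relation` key: target (`to`) first so all
/// relations aimed at the same event sort together, source (`from`)
/// second. Big-endian makes byte order equal numeric order.
fn relation_key(to: u64, from: u64) -> [u8; 16] {
    let mut key = [0u8; 16];
    key[..8].copy_from_slice(&to.to_be_bytes());
    key[8..].copy_from_slice(&from.to_be_bytes());
    key
}

/// Recover the `from` half, as the stream does with
/// `u64_from_u8(&to_from[8..16])`.
fn relation_from(key: &[u8]) -> u64 {
    u64::from_be_bytes(key[8..16].try_into().expect("key has 16 bytes"))
}

fn main() {
    let key = relation_key(42, 7);
    assert_eq!(relation_from(&key), 7);
    // Keys for the same target share the first 8 bytes:
    assert!(relation_key(42, 8)[..8] == key[..8]);
}
```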
- .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((PduCount::Normal(from), pdu)) - }), - )) + pub(super) fn get_relations<'a>( + &'a self, user_id: &'a UserId, shortroomid: ShortRoomId, target: ShortEventId, from: PduCount, dir: Direction, + ) -> impl Stream<Item = PdusIterItem> + Send + '_ { + let mut current = ArrayVec::<u8, 16>::new(); + current.extend(target.to_be_bytes()); + current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes()); + let current = current.as_slice(); + match dir { + Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(), + Direction::Backward => self.tofrom_relation.rev_raw_keys_from(current).boxed(), + } + .ignore_err() + .ready_take_while(move |key| key.starts_with(&target.to_be_bytes())) + .map(|to_from| u64_from_u8(&to_from[8..16])) + .map(PduCount::from_unsigned) + .filter_map(move |shorteventid| async move { + let pdu_id: RawPduId = PduId { + shortroomid, + shorteventid, + } + .into(); + + let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; + + if pdu.sender != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + Some((shorteventid, pdu)) + }) } - pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { + pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) { for prev in event_ids { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(prev.as_bytes()); - self.referencedevents.insert(&key, &[])?; + let key = (room_id, prev); + self.referencedevents.put_raw(key, []); } - - Ok(()) } - pub(super) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(event_id.as_bytes()); - Ok(self.referencedevents.get(&key)?.is_some()) + pub(super) async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool { + let key = (room_id, event_id); + self.referencedevents.qry(&key).await.is_ok() } - pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { - self.softfailedeventids.insert(event_id.as_bytes(), &[]) - } + pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) { self.softfailedeventids.insert(event_id, []); } - pub(super) fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { - self.softfailedeventids - .get(event_id.as_bytes()) - .map(|o| o.is_some()) + pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool { + self.softfailedeventids.get(event_id).await.is_ok() } } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index d9eaf3244cef5b8e67300e8d38c9c307f88b9765..82d2ee35b12191f9f0fc283ab2768b9aba7e11c2 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,16 +1,11 @@ mod data; - use std::sync::Arc; -use conduit::{PduCount, PduEvent, Result}; -use ruma::{ - api::{client::relations::get_relating_events, Direction}, - events::{relation::RelationType, TimelineEventType}, - uint, EventId, RoomId, UInt, UserId, -}; -use serde::Deserialize; +use conduit::{PduCount, Result}; +use futures::StreamExt; +use ruma::{api::Direction, EventId, RoomId, UserId}; -use self::data::Data; +use self::data::{Data, PdusIterItem}; use crate::{rooms, Dep}; pub struct Service { @@ -20,26 +15,14 @@ pub struct Service { struct Services { short: Dep<rooms::short::Service>, - 
state_accessor: Dep<rooms::state_accessor::Service>, timeline: Dep<rooms::timeline::Service>, } -#[derive(Clone, Debug, Deserialize)] -struct ExtractRelType { - rel_type: RelationType, -} -#[derive(Clone, Debug, Deserialize)] -struct ExtractRelatesToEventId { - #[serde(rename = "m.relates_to")] - relates_to: ExtractRelType, -} - impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { services: Services { short: args.depend::<rooms::short::Service>("rooms::short"), - state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), }, db: Data::new(&args), @@ -51,152 +34,83 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { #[tracing::instrument(skip(self, from, to), level = "debug")] - pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { + pub fn add_relation(&self, from: PduCount, to: PduCount) { match (from, to) { (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), _ => { // TODO: Relations with backfilled pdus - - Ok(()) }, } } #[allow(clippy::too_many_arguments)] - pub fn paginate_relations_with_filter( - &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option<TimelineEventType>, - filter_rel_type: &Option<RelationType>, from: &Option<String>, to: &Option<String>, limit: &Option<UInt>, - recurse: bool, dir: Direction, - ) -> Result<get_relating_events::v1::Response> { - let from = match from { - Some(from) => PduCount::try_from_string(from)?, - None => match dir { - Direction::Forward => PduCount::min(), - Direction::Backward => PduCount::max(), - }, - }; - - let to = to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); - - // Use limit or else 10, with maximum 100 - let limit = limit - .unwrap_or_else(|| uint!(10)) - .try_into() - .unwrap_or(10) - .min(100); - - // Spec (v1.10) recommends depth of at least 3 - let depth: u8 = if recurse { - 3 - } else { - 1 - }; - - let relations_until = &self.relations_until(sender_user, room_id, target, from, depth)?; - let events: Vec<_> = relations_until // TODO: should be relations_after - .iter() - .filter(|(_, pdu)| { - filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) - && if let Ok(content) = - serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) - { - filter_rel_type - .as_ref() - .map_or(true, |r| &content.relates_to.rel_type == r) - } else { - false - } - }) - .take(limit) - .filter(|(_, pdu)| { - self.services - .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) - .unwrap_or(false) - }) - .take_while(|(k, _)| Some(k) != to.as_ref()) // Stop at `to` - .collect(); - - let next_token = events.last().map(|(count, _)| count).copied(); - - let events_chunk: Vec<_> = match dir { - Direction::Forward => events - .into_iter() - .map(|(_, pdu)| pdu.to_message_like_event()) - .collect(), - Direction::Backward => events - .into_iter() - .rev() // relations are always most recent first - .map(|(_, pdu)| pdu.to_message_like_event()) - .collect(), - }; - - Ok(get_relating_events::v1::Response { - chunk: events_chunk, - next_batch: next_token.map(|t| t.stringify()), - prev_batch: Some(from.stringify()), - recursion_depth: if recurse { - Some(depth.into()) - } else { - None - }, - }) - } - - pub fn relations_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8, - ) -> 
Result<Vec<(PduCount, PduEvent)>> { - let room_id = self.services.short.get_or_create_shortroomid(room_id)?; - #[allow(unknown_lints)] - #[allow(clippy::manual_unwrap_or_default)] - let target = match self.services.timeline.get_pdu_count(target)? { - Some(PduCount::Normal(c)) => c, + pub async fn get_relations( + &self, user_id: &UserId, room_id: &RoomId, target: &EventId, from: PduCount, limit: usize, max_depth: u8, + dir: Direction, + ) -> Vec<PdusIterItem> { + let room_id = self.services.short.get_or_create_shortroomid(room_id).await; + + let target = match self.services.timeline.get_pdu_count(target).await { + Ok(PduCount::Normal(c)) => c, // TODO: Support backfilled relations _ => 0, // This will result in an empty iterator }; - self.db - .relations_until(user_id, room_id, target, until) - .map(|mut relations| { - let mut pdus: Vec<_> = (*relations).into_iter().filter_map(Result::ok).collect(); - let mut stack: Vec<_> = pdus.clone().iter().map(|pdu| (pdu.to_owned(), 1)).collect(); - - while let Some(stack_pdu) = stack.pop() { - let target = match stack_pdu.0 .0 { - PduCount::Normal(c) => c, - // TODO: Support backfilled relations - PduCount::Backfilled(_) => 0, // This will result in an empty iterator - }; - - if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) { - for relation in relations.flatten() { - if stack_pdu.1 < max_depth { - stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); - } - - pdus.push(relation); - } - } + let mut pdus: Vec<_> = self + .db + .get_relations(user_id, room_id, target, from, dir) + .collect() + .await; + + let mut stack: Vec<_> = pdus.iter().map(|pdu| (pdu.clone(), 1)).collect(); + + 'limit: while let Some(stack_pdu) = stack.pop() { + let target = match stack_pdu.0 .0 { + PduCount::Normal(c) => c, + // TODO: Support backfilled relations + PduCount::Backfilled(_) => 0, // This will result in an empty iterator + }; + + let relations: Vec<_> = self + .db + .get_relations(user_id, room_id, target, from, dir) + .collect() + .await; + + for relation in relations { + if stack_pdu.1 < max_depth { + stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); } - pdus.sort_by(|a, b| a.0.cmp(&b.0)); - pdus - }) + pdus.push(relation); + if pdus.len() >= limit { + break 'limit; + } + } + } + + pdus } + #[inline] #[tracing::instrument(skip_all, level = "debug")] - pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { - self.db.mark_as_referenced(room_id, event_ids) + pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) { + self.db.mark_as_referenced(room_id, event_ids); } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> { - self.db.is_event_referenced(room_id, event_id) + pub async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool { + self.db.is_event_referenced(room_id, event_id).await } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) } + pub fn mark_event_soft_failed(&self, event_id: &EventId) { self.db.mark_event_soft_failed(event_id) } + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { self.db.is_event_soft_failed(event_id) } + pub async fn is_event_soft_failed(&self, event_id: &EventId) -> bool { + 
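The new `get_relations` in mod.rs replaces the old `relations_until` recursion with an explicit stack: each popped entry is re-queried for its own relations, children are pushed while below `max_depth`, and the labeled `'limit` break stops the walk once `limit` results are collected. The traversal's shape, abstracted over a child-lookup closure:

```rust
/// Depth-capped, limit-capped transitive expansion, mirroring the
/// labeled-break traversal above. `children` stands in for the
/// per-event database query.
fn expand<F>(roots: Vec<u64>, children: F, max_depth: u8, limit: usize) -> Vec<u64>
where
    F: Fn(u64) -> Vec<u64>,
{
    let mut results = roots.clone();
    let mut stack: Vec<(u64, u8)> = roots.into_iter().map(|id| (id, 1)).collect();

    'limit: while let Some((id, depth)) = stack.pop() {
        for child in children(id) {
            if depth < max_depth {
                stack.push((child, depth + 1));
            }
            results.push(child);
            if results.len() >= limit {
                break 'limit; // stop once enough PDUs are collected
            }
        }
    }

    results
}

fn main() {
    // Chain 1 -> 2 -> 3 -> 4; with max_depth 2 the walk records 2's
    // child but never expands it, so 4 is unreachable.
    let children = |id: u64| if id < 4 { vec![id + 1] } else { vec![] };
    assert_eq!(expand(vec![1], children, 2, 10), vec![1, 2, 3]);
}
```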
self.db.is_event_soft_failed(event_id).await + } } diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 0c156df38a5dc97546b465ad9c94d24c6abce72d..80a35e8813cdb63ca74776d29fdbba28b2305c64 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -1,10 +1,18 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, RoomId, UserId}; +use conduit::{ + utils, + utils::{stream::TryIgnore, ReadyExt}, + Error, Result, +}; +use database::{Deserialized, Json, Map}; +use futures::{Stream, StreamExt}; +use ruma::{ + events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, + serde::Raw, + CanonicalJsonObject, OwnedUserId, RoomId, UserId, +}; -use super::AnySyncEphemeralRoomEventIter; use crate::{globals, Dep}; pub(super) struct Data { @@ -18,6 +26,8 @@ struct Services { globals: Dep<globals::Service>, } +pub(super) type ReceiptItem = (OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>); + impl Data { pub(super) fn new(args: &crate::Args<'_>) -> Self { let db = &args.db; @@ -31,116 +41,82 @@ pub(super) fn new(args: &crate::Args<'_>) -> Self { } } - pub(super) fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + pub(super) async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) { + type KeyVal<'a> = (&'a RoomId, u64, &'a UserId); // Remove old entry - if let Some((old, _)) = self - .readreceiptid_readreceipt - .iter_from(&last_possible_key, true) - .take_while(|(key, _)| key.starts_with(&prefix)) - .find(|(key, _)| { - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element") - == user_id.as_bytes() - }) { - // This is the old room_latest - self.readreceiptid_readreceipt.remove(&old)?; - } - - let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - room_latest_id.push(0xFF); - room_latest_id.extend_from_slice(user_id.as_bytes()); - - self.readreceiptid_readreceipt.insert( - &room_latest_id, - &serde_json::to_vec(event).expect("EduEvent::to_string always works"), - )?; - - Ok(()) + let last_possible_key = (room_id, u64::MAX); + self.readreceiptid_readreceipt + .rev_keys_from(&last_possible_key) + .ignore_err() + .ready_take_while(|(r, ..): &KeyVal<'_>| *r == room_id) + .ready_filter_map(|(r, c, u): KeyVal<'_>| (u == user_id).then_some((r, c, u))) + .ready_for_each(|old: KeyVal<'_>| self.readreceiptid_readreceipt.del(old)) + .await; + + let count = self.services.globals.next_count().unwrap(); + let latest_id = (room_id, count, user_id); + self.readreceiptid_readreceipt.put(latest_id, Json(event)); } - pub(super) fn readreceipts_since<'a>(&'a self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'a> { + pub(super) fn readreceipts_since<'a>( + &'a self, room_id: &'a RoomId, since: u64, + ) -> impl Stream<Item = ReceiptItem> + Send + 'a { + let after_since = since.saturating_add(1); // +1 so we don't send the event at since + let first_possible_edu = (room_id, after_since); + let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); let prefix2 = prefix.clone(); - let mut first_possible_edu = prefix.clone(); - 
first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since - - Box::new( - self.readreceiptid_readreceipt - .iter_from(&first_possible_edu, false) - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(k, v)| { - let count_offset = prefix.len().saturating_add(size_of::<u64>()); - let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) - .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; - let user_id_offset = count_offset.saturating_add(1); - let user_id = UserId::parse( - utils::string_from_bytes(&k[user_id_offset..]) - .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?, - ) + self.readreceiptid_readreceipt + .stream_from_raw(&first_possible_edu) + .ignore_err() + .ready_take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(k, v)| { + let count_offset = prefix.len().saturating_add(size_of::<u64>()); + let user_id_offset = count_offset.saturating_add(1); + + let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) + .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + + let user_id_str = utils::string_from_bytes(&k[user_id_offset..]) + .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?; + + let user_id = UserId::parse(user_id_str) .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; - let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v) - .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; - json.remove("room_id"); - - Ok(( - user_id, - count, - Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")), - )) - }), - ) - } + let mut json = serde_json::from_slice::<CanonicalJsonObject>(v) + .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; - pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); + json.remove("room_id"); - self.roomuserid_privateread - .insert(&key, &count.to_be_bytes())?; + let event = Raw::from_json(serde_json::value::to_raw_value(&json)?); - self.roomuserid_lastprivatereadupdate - .insert(&key, &self.services.globals.next_count()?.to_be_bytes()) + Ok((user_id, count, event)) + }) + .ignore_err() } - pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_privateread - .get(&key)? - .map_or(Ok(None), |v| { - Ok(Some( - utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?, - )) - }) + pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) { + let key = (room_id, user_id); + let next_count = self.services.globals.next_count().unwrap(); + + self.roomuserid_privateread.put(key, count); + self.roomuserid_lastprivatereadupdate.put(key, next_count); } - pub(super) fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - Ok(self - .roomuserid_lastprivatereadupdate - .get(&key)? 
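`readreceipts_since` recovers three fields from every `readreceiptid_readreceipt` key: the room prefix, an 8-byte big-endian count, a 0xFF separator, then the user ID. A sketch of that fixed-offset parse:

```rust
/// Parse `roomid ++ 0xFF ++ count(BE u64) ++ 0xFF ++ userid`, using the
/// same offsets as `readreceipts_since` (prefix length, then 8 bytes,
/// then one separator byte).
fn parse_receipt_key(key: &[u8], room_id: &str) -> Option<(u64, String)> {
    let prefix_len = room_id.len() + 1; // room id plus the 0xFF separator
    let count_end = prefix_len + 8;

    if !key.starts_with(room_id.as_bytes()) || key.len() < count_end + 1 {
        return None;
    }

    let count = u64::from_be_bytes(key[prefix_len..count_end].try_into().ok()?);
    let user_id = String::from_utf8(key[count_end + 1..].to_vec()).ok()?;
    Some((count, user_id))
}

fn main() {
    let mut key = b"!r:example.org".to_vec();
    key.push(0xFF);
    key.extend_from_slice(&1234u64.to_be_bytes());
    key.push(0xFF);
    key.extend_from_slice(b"@alice:example.org");

    assert_eq!(
        parse_receipt_key(&key, "!r:example.org"),
        Some((1234, "@alice:example.org".to_owned()))
    );
}
```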
- .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) - }) - .transpose()? - .unwrap_or(0)) + pub(super) async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> { + let key = (room_id, user_id); + self.roomuserid_privateread.qry(&key).await.deserialized() + } + + pub(super) async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (room_id, user_id); + self.roomuserid_lastprivatereadupdate + .qry(&key) + .await + .deserialized() + .unwrap_or(0) } } diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index da11e2a0f0450b7e7be3154a89dcac99991638c6..ec34361e0be8b0896f18d845ce9afd6539cd383c 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -3,16 +3,17 @@ use std::{collections::BTreeMap, sync::Arc}; use conduit::{debug, Result}; -use data::Data; +use futures::Stream; use ruma::{ events::{ receipt::{ReceiptEvent, ReceiptEventContent}, - AnySyncEphemeralRoomEvent, SyncEphemeralRoomEvent, + SyncEphemeralRoomEvent, }, serde::Raw, - OwnedUserId, RoomId, UserId, + RoomId, UserId, }; +use self::data::{Data, ReceiptItem}; use crate::{sending, Dep}; pub struct Service { @@ -24,9 +25,6 @@ struct Services { sending: Dep<sending::Service>, } -type AnySyncEphemeralRoomEventIter<'a> = - Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>; - impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { @@ -42,44 +40,53 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { /// Replaces the previous read receipt. - pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { - self.db.readreceipt_update(user_id, room_id, event)?; - self.services.sending.flush_room(room_id)?; - - Ok(()) + pub async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) { + self.db.readreceipt_update(user_id, room_id, event).await; + self.services + .sending + .flush_room(room_id) + .await + .expect("room flush failed"); } /// Returns an iterator over the most recent read_receipts in a room that /// happened after the event with id `since`. + #[inline] #[tracing::instrument(skip(self), level = "debug")] pub fn readreceipts_since<'a>( - &'a self, room_id: &RoomId, since: u64, - ) -> impl Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a { + &'a self, room_id: &'a RoomId, since: u64, + ) -> impl Stream<Item = ReceiptItem> + Send + 'a { self.db.readreceipts_since(room_id, since) } /// Sets a private read marker at `count`. + #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - self.db.private_read_set(room_id, user_id, count) + pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) { + self.db.private_read_set(room_id, user_id, count); } /// Returns the private read marker. 
+ #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { - self.db.private_read_get(room_id, user_id) + pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> { + self.db.private_read_get(room_id, user_id).await } /// Returns the count of the last private read update in this room. - pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - self.db.last_privateread_update(user_id, room_id) + #[inline] + pub async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + self.db.last_privateread_update(user_id, room_id).await } } #[must_use] -pub fn pack_receipts(receipts: AnySyncEphemeralRoomEventIter<'_>) -> Raw<SyncEphemeralRoomEvent<ReceiptEventContent>> { +pub fn pack_receipts<I>(receipts: I) -> Raw<SyncEphemeralRoomEvent<ReceiptEventContent>> +where + I: Iterator<Item = ReceiptItem>, +{ let mut json = BTreeMap::new(); - for (_user, _count, value) in receipts.flatten() { + for (_, _, value) in receipts { let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(value.json().get()); if let Ok(value) = receipt { for (event, receipt) in value.content { diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs deleted file mode 100644 index a0086095bda93d24fd6350da446c9950c968221d..0000000000000000000000000000000000000000 --- a/src/service/rooms/search/data.rs +++ /dev/null @@ -1,108 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Result}; -use database::Map; -use ruma::RoomId; - -use crate::{rooms, Dep}; - -type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>; - -pub(super) struct Data { - tokenids: Arc<Map>, - services: Services, -} - -struct Services { - short: Dep<rooms::short::Service>, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - tokenids: db["tokenids"].clone(), - services: Services { - short: args.depend::<rooms::short::Service>("rooms::short"), - }, - } - } - - pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - let batch = tokenize(message_body) - .map(|word| { - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(word.as_bytes()); - key.push(0xFF); - key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here - (key, Vec::<u8>::new()) - }) - .collect::<Vec<_>>(); - - self.tokenids - .insert_batch(batch.iter().map(database::KeyVal::from)) - } - - pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - let batch = tokenize(message_body).map(|word| { - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(word.as_bytes()); - key.push(0xFF); - key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here - key - }); - - for token in batch { - self.tokenids.remove(&token)?; - } - - Ok(()) - } - - pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { - let prefix = self - .services - .short - .get_shortroomid(room_id)?
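`pack_receipts` is truncated by the diff context, but the visible part deserializes each stored receipt EDU and folds its per-event content into one BTreeMap keyed by event ID. A rough sketch of that fold (the overwrite-on-duplicate behaviour past the visible lines is an assumption, and plain strings stand in for ruma's receipt types):

```rust
use std::collections::BTreeMap;

/// Merge several receipt EDUs into one map: a later receipt for the
/// same event id replaces an earlier one via BTreeMap::insert.
fn pack_receipts<I>(receipts: I) -> BTreeMap<String, String>
where
    I: Iterator<Item = BTreeMap<String, String>>, // one map per stored EDU
{
    let mut json = BTreeMap::new();
    for content in receipts {
        for (event_id, receipt) in content {
            json.insert(event_id, receipt);
        }
    }
    json
}

fn main() {
    let a = BTreeMap::from([("$e1".to_owned(), "@alice read".to_owned())]);
    let b = BTreeMap::from([("$e2".to_owned(), "@bob read".to_owned())]);
    let packed = pack_receipts([a, b].into_iter());
    assert_eq!(packed.len(), 2);
}
```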
- .expect("room exists") - .to_be_bytes() - .to_vec(); - - let words: Vec<_> = tokenize(search_string).collect(); - - let iterators = words.clone().into_iter().map(move |word| { - let mut prefix2 = prefix.clone(); - prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xFF); - let prefix3 = prefix2.clone(); - - let mut last_possible_id = prefix2.clone(); - last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.tokenids - .iter_from(&last_possible_id, true) // Newest pdus first - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(key, _)| key[prefix3.len()..].to_vec()) - }); - - let Some(common_elements) = utils::common_elements(iterators, |a, b| { - // We compare b with a because we reversed the iterator earlier - b.cmp(a) - }) else { - return Ok(None); - }; - - Ok(Some((Box::new(common_elements), words))) - } -} - -/// Splits a string into tokens used as keys in the search inverted index -/// -/// This may be used to tokenize both message bodies (for indexing) or search -/// queries (for querying). -fn tokenize(body: &str) -> impl Iterator<Item = String> + '_ { - body.split_terminator(|c: char| !c.is_alphanumeric()) - .filter(|s| !s.is_empty()) - .filter(|word| word.len() <= 50) - .map(str::to_lowercase) -} diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 8caa0ce3525902037643565bc7756fb942150f49..1af37d9e5fd2b862b8d17a4f807a71c4d54da601 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,40 +1,217 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use data::Data; -use ruma::RoomId; +use arrayvec::ArrayVec; +use conduit::{ + implement, + utils::{set, stream::TryIgnore, ArrayVecExt, IterStream, ReadyExt}, + PduCount, PduEvent, Result, +}; +use database::{keyval::Val, Map}; +use futures::{Stream, StreamExt}; +use ruma::{api::client::search::search_events::v3::Criteria, RoomId, UserId}; + +use crate::{ + rooms, + rooms::{ + short::ShortRoomId, + timeline::{PduId, RawPduId}, + }, + Dep, +}; pub struct Service { db: Data, + services: Services, +} + +struct Data { + tokenids: Arc<Map>, +} + +struct Services { + short: Dep<rooms::short::Service>, + state_accessor: Dep<rooms::state_accessor::Service>, + timeline: Dep<rooms::timeline::Service>, +} + +#[derive(Clone, Debug)] +pub struct RoomQuery<'a> { + pub room_id: &'a RoomId, + pub user_id: Option<&'a UserId>, + pub criteria: &'a Criteria, + pub limit: usize, + pub skip: usize, } +type TokenId = ArrayVec<u8, TOKEN_ID_MAX_LEN>; + +const TOKEN_ID_MAX_LEN: usize = size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>(); +const WORD_MAX_LEN: usize = 50; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + tokenids: args.db["tokenids"].clone(), + }, + services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), + state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - #[tracing::instrument(skip(self), level = "debug")] - pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - self.db.index_pdu(shortroomid, pdu_id, message_body) - } +#[implement(Service)] +pub fn index_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) { + let batch = 
tokenize(message_body) + .map(|word| { + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(word.as_bytes()); + key.push(0xFF); + key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here + (key, Vec::<u8>::new()) + }) + .collect::<Vec<_>>(); + + self.db.tokenids.insert_batch(batch.iter()); +} - #[tracing::instrument(skip(self), level = "debug")] - pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - self.db.deindex_pdu(shortroomid, pdu_id, message_body) +#[implement(Service)] +pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) { + let batch = tokenize(message_body).map(|word| { + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(word.as_bytes()); + key.push(0xFF); + key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here + key + }); + + for token in batch { + self.db.tokenids.remove(&token); } +} + +#[implement(Service)] +pub async fn search_pdus<'a>( + &'a self, query: &'a RoomQuery<'a>, +) -> Result<(usize, impl Stream<Item = PduEvent> + Send + 'a)> { + let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await; + + let count = pdu_ids.len(); + let pdus = pdu_ids + .into_iter() + .stream() + .filter_map(move |result_pdu_id: RawPduId| async move { + self.services + .timeline + .get_pdu_from_id(&result_pdu_id) + .await + .ok() + }) + .ready_filter(|pdu| !pdu.is_redacted()) + .ready_filter(|pdu| pdu.matches(&query.criteria.filter)) + .filter_map(move |pdu| async move { + self.services + .state_accessor + .user_can_see_event(query.user_id?, &pdu.room_id, &pdu.event_id) + .await + .then_some(pdu) + }) + .skip(query.skip) + .take(query.limit); + + Ok((count, pdus)) +} - #[tracing::instrument(skip(self), level = "debug")] - pub fn search_pdus<'a>( - &'a self, room_id: &RoomId, search_string: &str, - ) -> Result<Option<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)>> { - self.db.search_pdus(room_id, search_string) +// result is modeled as a stream such that callers don't have to be refactored +// though an additional async/wrap still exists for now +#[implement(Service)] +pub async fn search_pdu_ids(&self, query: &RoomQuery<'_>) -> Result<impl Stream<Item = RawPduId> + Send + '_> { + let shortroomid = self.services.short.get_shortroomid(query.room_id).await?; + + let pdu_ids = self.search_pdu_ids_query_room(query, shortroomid).await; + + let iters = pdu_ids.into_iter().map(IntoIterator::into_iter); + + Ok(set::intersection(iters).stream()) +} + +#[implement(Service)] +async fn search_pdu_ids_query_room(&self, query: &RoomQuery<'_>, shortroomid: ShortRoomId) -> Vec<Vec<RawPduId>> { + tokenize(&query.criteria.search_term) + .stream() + .then(|word| async move { + self.search_pdu_ids_query_words(shortroomid, &word) + .collect::<Vec<_>>() + .await + }) + .collect::<Vec<_>>() + .await +} + +/// Iterate over PduId's containing a word +#[implement(Service)] +fn search_pdu_ids_query_words<'a>( + &'a self, shortroomid: ShortRoomId, word: &'a str, +) -> impl Stream<Item = RawPduId> + Send + '_ { + self.search_pdu_ids_query_word(shortroomid, word) + .map(move |key| -> RawPduId { + let key = &key[prefix_len(word)..]; + key.into() + }) +} + +/// Iterate over raw database results for a word +#[implement(Service)] +fn search_pdu_ids_query_word(&self, shortroomid: ShortRoomId, word: &str) -> impl Stream<Item = Val<'_>> + Send + '_ { + // rustc says const'ing this not yet stable + let 
end_id: RawPduId = PduId { shortroomid, shorteventid: PduCount::max(), } .into(); + + // Newest pdus first + let end = make_tokenid(shortroomid, word, &end_id); + let prefix = make_prefix(shortroomid, word); + self.db + .tokenids + .rev_raw_keys_from(&end) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) +} + +/// Splits a string into tokens used as keys in the search inverted index +/// +/// This may be used to tokenize both message bodies (for indexing) and search +/// queries (for querying). +fn tokenize(body: &str) -> impl Iterator<Item = String> + Send + '_ { + body.split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) + .filter(|word| word.len() <= WORD_MAX_LEN) + .map(str::to_lowercase) +} + +fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &RawPduId) -> TokenId { + let mut key = make_prefix(shortroomid, word); + key.extend_from_slice(pdu_id.as_ref()); + key +} + +fn make_prefix(shortroomid: ShortRoomId, word: &str) -> TokenId { + let mut key = TokenId::new(); + key.extend_from_slice(&shortroomid.to_be_bytes()); + key.extend_from_slice(word.as_bytes()); + key.push(database::SEP); + key +} + +fn prefix_len(word: &str) -> usize { + size_of::<ShortRoomId>() .saturating_add(word.len()) .saturating_add(1) } diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs deleted file mode 100644 index 17fbb64e8dd5d6bbdc08dd82442b26f64c335680..0000000000000000000000000000000000000000 --- a/src/service/rooms/short/data.rs +++ /dev/null @@ -1,195 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, warn, Error, Result}; -use database::Map; -use ruma::{events::StateEventType, EventId, RoomId}; - -use crate::{globals, Dep}; - -pub(super) struct Data { - eventid_shorteventid: Arc<Map>, - shorteventid_eventid: Arc<Map>, - statekey_shortstatekey: Arc<Map>, - shortstatekey_statekey: Arc<Map>, - roomid_shortroomid: Arc<Map>, - statehash_shortstatehash: Arc<Map>, - services: Services, -} - -struct Services { - globals: Dep<globals::Service>, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - eventid_shorteventid: db["eventid_shorteventid"].clone(), - shorteventid_eventid: db["shorteventid_eventid"].clone(), - statekey_shortstatekey: db["statekey_shortstatekey"].clone(), - shortstatekey_statekey: db["shortstatekey_statekey"].clone(), - roomid_shortroomid: db["roomid_shortroomid"].clone(), - statehash_shortstatehash: db["statehash_shortstatehash"].clone(), - services: Services { - globals: args.depend::<globals::Service>("globals"), - }, - } - } - - pub(super) fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> { - let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { - utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? - } else { - let shorteventid = self.services.globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid - }; - - Ok(short) - } - - pub(super) fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> { - let mut ret: Vec<u64> = Vec::with_capacity(event_ids.len()); - let keys = event_ids - .iter() - .map(|id| id.as_bytes()) - .collect::<Vec<&[u8]>>(); - for (i, short) in self - .eventid_shorteventid - .multi_get(&keys)?
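`search_pdu_ids` scans each token of the search term separately and feeds the per-word PDU-ID lists to `set::intersection`, so only events containing every word survive; each list arrives newest-first from the reverse key scan. A sketch of that AND-combination over plain vectors:

```rust
use std::collections::HashSet;

/// Keep only ids present in every per-word result list, preserving the
/// order of the first scan (newest first), as set::intersection does.
fn intersect_word_hits(lists: Vec<Vec<u64>>) -> Vec<u64> {
    let mut iter = lists.into_iter();
    let Some(first) = iter.next() else { return Vec::new() };
    let rest: Vec<HashSet<u64>> = iter.map(|l| l.into_iter().collect()).collect();
    first
        .into_iter()
        .filter(|id| rest.iter().all(|set| set.contains(id)))
        .collect()
}

fn main() {
    // "hello world" -> ids containing "hello" AND ids containing "world"
    let hello = vec![30, 20, 10];
    let world = vec![30, 10, 5];
    assert_eq!(intersect_word_hits(vec![hello, world]), vec![30, 10]);
}
```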
- .iter() - .enumerate() - { - #[allow(clippy::single_match_else)] - match short { - Some(short) => ret.push( - utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, - ), - None => { - let short = self.services.globals.next_count()?; - self.eventid_shorteventid - .insert(keys[i], &short.to_be_bytes())?; - self.shorteventid_eventid - .insert(&short.to_be_bytes(), keys[i])?; - - debug_assert!(ret.len() == i, "position of result must match input"); - ret.push(short); - }, - } - } - - Ok(ret) - } - - pub(super) fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> { - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xFF); - statekey_vec.extend_from_slice(state_key.as_bytes()); - - let short = self - .statekey_shortstatekey - .get(&statekey_vec)? - .map(|shortstatekey| { - utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) - }) - .transpose()?; - - Ok(short) - } - - pub(super) fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> { - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xFF); - statekey_vec.extend_from_slice(state_key.as_bytes()); - - let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? { - utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? - } else { - let shortstatekey = self.services.globals.next_count()?; - self.statekey_shortstatekey - .insert(&statekey_vec, &shortstatekey.to_be_bytes())?; - self.shortstatekey_statekey - .insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; - shortstatekey - }; - - Ok(short) - } - - pub(super) fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { - let bytes = self - .shorteventid_eventid - .get(&shorteventid.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; - - let event_id = EventId::parse_arc( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; - - Ok(event_id) - } - - pub(super) fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - let bytes = self - .shortstatekey_statekey - .get(&shortstatekey.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; - - let mut parts = bytes.splitn(2, |&b| b == 0xFF); - let eventtype_bytes = parts.next().expect("split always returns one entry"); - let statekey_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; - - let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { - warn!("Event type in shortstatekey_statekey is invalid: {}", e); - Error::bad_database("Event type in shortstatekey_statekey is invalid.") - })?); - - let state_key = utils::string_from_bytes(statekey_bytes) - .map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?; - - let result = (event_type, state_key); - - Ok(result) - } - - /// Returns (shortstatehash, already_existed) - pub(super) fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? 
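Both the deleted `Data` methods and their `#[implement(Service)]` replacements share one allocation pattern: look the key up, and on a miss draw the next value from the global counter and write both directions (key→short and short→key) so either side resolves later. Distilled over two BTreeMaps, with a plain integer standing in for `globals.next_count()`:

```rust
use std::collections::BTreeMap;

/// Two-way short-id allocation: forward (key -> id) and reverse
/// (id -> key) tables are written together; ids come from a counter.
struct ShortIds {
    forward: BTreeMap<Vec<u8>, u64>,
    reverse: BTreeMap<u64, Vec<u8>>,
    next: u64, // stands in for services.globals.next_count()
}

impl ShortIds {
    fn get_or_create(&mut self, key: &[u8]) -> u64 {
        if let Some(&short) = self.forward.get(key) {
            return short;
        }
        self.next += 1;
        let short = self.next;
        self.forward.insert(key.to_vec(), short);
        self.reverse.insert(short, key.to_vec());
        short
    }
}

fn main() {
    let mut ids = ShortIds { forward: BTreeMap::new(), reverse: BTreeMap::new(), next: 0 };
    let a = ids.get_or_create(b"$event_a");
    assert_eq!(ids.get_or_create(b"$event_a"), a); // stable on repeat
    assert_eq!(ids.reverse[&a], b"$event_a".to_vec()); // resolvable back
}
```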
{ - ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ) - } else { - let shortstatehash = self.services.globals.next_count()?; - self.statehash_shortstatehash - .insert(state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - }) - } - - pub(super) fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { - self.roomid_shortroomid - .get(room_id.as_bytes())? - .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))) - .transpose() - } - - pub(super) fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> { - Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { - utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? - } else { - let short = self.services.globals.next_count()?; - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes())?; - short - }) - } -} diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index bfe0e9a0ef044a621fa5281bd06d57f18da9d23b..703df796aff8264c0e2335e176d472c314774dff 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,59 +1,238 @@ -mod data; +use std::{mem::size_of_val, sync::Arc}; -use std::sync::Arc; - -use conduit::Result; +pub use conduit::pdu::{ShortEventId, ShortId, ShortRoomId}; +use conduit::{err, implement, utils, Result}; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{events::StateEventType, EventId, RoomId}; -use self::data::Data; +use crate::{globals, Dep}; pub struct Service { db: Data, + services: Services, +} + +struct Data { + eventid_shorteventid: Arc<Map>, + shorteventid_eventid: Arc<Map>, + statekey_shortstatekey: Arc<Map>, + shortstatekey_statekey: Arc<Map>, + roomid_shortroomid: Arc<Map>, + statehash_shortstatehash: Arc<Map>, } +struct Services { + globals: Dep<globals::Service>, +} + +pub type ShortStateHash = ShortId; +pub type ShortStateKey = ShortId; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + eventid_shorteventid: args.db["eventid_shorteventid"].clone(), + shorteventid_eventid: args.db["shorteventid_eventid"].clone(), + statekey_shortstatekey: args.db["statekey_shortstatekey"].clone(), + shortstatekey_statekey: args.db["shortstatekey_statekey"].clone(), + roomid_shortroomid: args.db["roomid_shortroomid"].clone(), + statehash_shortstatehash: args.db["statehash_shortstatehash"].clone(), + }, + services: Services { + globals: args.depend::<globals::Service>("globals"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> { - self.db.get_or_create_shorteventid(event_id) - } +#[implement(Service)] +pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEventId { + const BUFSIZE: usize = size_of::<ShortEventId>(); - pub fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> { - self.db.multi_get_or_create_shorteventid(event_ids) + if let Ok(shorteventid) = self.get_shorteventid(event_id).await { + return shorteventid; } - pub fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> { - self.db.get_shortstatekey(event_type, state_key) - } + let shorteventid = 
self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&shorteventid) == BUFSIZE, "buffer requirement changed"); - pub fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> { - self.db.get_or_create_shortstatekey(event_type, state_key) - } + self.db + .eventid_shorteventid + .raw_aput::<BUFSIZE, _, _>(event_id, shorteventid); - pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { - self.db.get_eventid_from_short(shorteventid) - } + self.db + .shorteventid_eventid + .aput_raw::<BUFSIZE, _, _>(shorteventid, event_id); - pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - self.db.get_statekey_from_short(shortstatekey) - } + shorteventid +} - /// Returns (shortstatehash, already_existed) - pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - self.db.get_or_create_shortstatehash(state_hash) +#[implement(Service)] +pub fn multi_get_or_create_shorteventid<'a>( + &'a self, event_ids: &'a [&EventId], +) -> impl Stream<Item = ShortEventId> + Send + 'a { + self.db + .eventid_shorteventid + .get_batch(event_ids.iter()) + .enumerate() + .map(|(i, result)| match result { + Ok(ref short) => utils::u64_from_u8(short), + Err(_) => { + const BUFSIZE: usize = size_of::<ShortEventId>(); + + let short = self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); + + self.db + .eventid_shorteventid + .raw_aput::<BUFSIZE, _, _>(event_ids[i], short); + self.db + .shorteventid_eventid + .aput_raw::<BUFSIZE, _, _>(short, event_ids[i]); + + short + }, + }) +} + +#[implement(Service)] +pub async fn get_shorteventid(&self, event_id: &EventId) -> Result<ShortEventId> { + self.db + .eventid_shorteventid + .get(event_id) + .await + .deserialized() +} + +#[implement(Service)] +pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> ShortStateKey { + const BUFSIZE: usize = size_of::<ShortStateKey>(); + + if let Ok(shortstatekey) = self.get_shortstatekey(event_type, state_key).await { + return shortstatekey; } - pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.get_shortroomid(room_id) } + let key = (event_type, state_key); + let shortstatekey = self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&shortstatekey) == BUFSIZE, "buffer requirement changed"); + + self.db + .statekey_shortstatekey + .put_aput::<BUFSIZE, _, _>(key, shortstatekey); + + self.db + .shortstatekey_statekey + .aput_put::<BUFSIZE, _, _>(shortstatekey, key); + + shortstatekey +} - pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> { - self.db.get_or_create_shortroomid(room_id) +#[implement(Service)] +pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<ShortStateKey> { + let key = (event_type, state_key); + self.db + .statekey_shortstatekey + .qry(&key) + .await + .deserialized() +} + +#[implement(Service)] +pub async fn get_eventid_from_short(&self, shorteventid: ShortEventId) -> Result<Arc<EventId>> { + const BUFSIZE: usize = size_of::<ShortEventId>(); + + self.db + .shorteventid_eventid + .aqry::<BUFSIZE, _>(&shorteventid) + .await + .deserialized() + .map_err(|e| err!(Database("Failed to find EventId from short {shorteventid:?}: {e:?}"))) +} + +#[implement(Service)] +pub async fn multi_get_eventid_from_short(&self, shorteventid: &[ShortEventId]) -> 
Vec<Result<Arc<EventId>>> { + const BUFSIZE: usize = size_of::<ShortEventId>(); + + let keys: Vec<[u8; BUFSIZE]> = shorteventid + .iter() + .map(|short| short.to_be_bytes()) + .collect(); + + self.db + .shorteventid_eventid + .get_batch(keys.iter()) + .map(Deserialized::deserialized) + .collect() + .await +} + +#[implement(Service)] +pub async fn get_statekey_from_short(&self, shortstatekey: ShortStateKey) -> Result<(StateEventType, String)> { + const BUFSIZE: usize = size_of::<ShortStateKey>(); + + self.db + .shortstatekey_statekey + .aqry::<BUFSIZE, _>(&shortstatekey) + .await + .deserialized() + .map_err(|e| { + err!(Database( + "Failed to find (StateEventType, state_key) from short {shortstatekey:?}: {e:?}" + )) + }) +} + +/// Returns (shortstatehash, already_existed) +#[implement(Service)] +pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (ShortStateHash, bool) { + const BUFSIZE: usize = size_of::<ShortStateHash>(); + + if let Ok(shortstatehash) = self + .db + .statehash_shortstatehash + .get(state_hash) + .await + .deserialized() + { + return (shortstatehash, true); } + + let shortstatehash = self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&shortstatehash) == BUFSIZE, "buffer requirement changed"); + + self.db + .statehash_shortstatehash + .raw_aput::<BUFSIZE, _, _>(state_hash, shortstatehash); + + (shortstatehash, false) +} + +#[implement(Service)] +pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result<ShortRoomId> { + self.db.roomid_shortroomid.get(room_id).await.deserialized() +} + +#[implement(Service)] +pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> ShortRoomId { + self.db + .roomid_shortroomid + .get(room_id) + .await + .deserialized() + .unwrap_or_else(|_| { + const BUFSIZE: usize = size_of::<ShortRoomId>(); + + let short = self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); + + self.db + .roomid_shortroomid + .raw_aput::<BUFSIZE, _, _>(room_id, short); + + short + }) } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 24d612d875a969a1497eb4f70ff83d9c328079ff..0ef7ddf56f90c6c0ba513d63b6f6ac8bc946b606 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -7,7 +7,12 @@ sync::Arc, }; -use conduit::{checked, debug, debug_info, err, utils::math::usize_from_f64, warn, Error, Result}; +use conduit::{ + checked, debug_info, err, + utils::{math::usize_from_f64, IterStream}, + Error, Result, +}; +use futures::{StreamExt, TryFutureExt}; use lru_cache::LruCache; use ruma::{ api::{ @@ -28,7 +33,7 @@ }; use tokio::sync::Mutex; -use crate::{rooms, sending, Dep}; +use crate::{rooms, rooms::short::ShortRoomId, sending, Dep}; pub struct CachedSpaceHierarchySummary { summary: SpaceHierarchyParentSummary, @@ -44,7 +49,7 @@ pub enum SummaryAccessibility { pub struct PaginationToken { /// Path down the hierarchy of the room to start the response at, /// excluding the root space. - pub short_room_ids: Vec<u64>, + pub short_room_ids: Vec<ShortRoomId>, pub limit: UInt, pub max_depth: UInt, pub suggested_only: bool, @@ -57,11 +62,11 @@ fn from_str(value: &str) -> Result<Self> { let mut values = value.split('_'); let mut pag_tok = || { - let mut rooms = vec![]; - - for room in values.next()?.split(',') { - rooms.push(u64::from_str(room).ok()?); - } + let rooms = values + .next()? 
+ .split(',') + .filter_map(|room_s| u64::from_str(room_s).ok()) + .collect(); Some(Self { short_room_ids: rooms, @@ -211,12 +216,15 @@ async fn get_summary_and_children_local( .as_ref() { return Ok(if let Some(cached) = cached { - if self.is_accessible_child( - current_room, - &cached.summary.join_rule, - &identifier, - &cached.summary.allowed_room_ids, - ) { + if self + .is_accessible_child( + current_room, + &cached.summary.join_rule, + &identifier, + &cached.summary.allowed_room_ids, + ) + .await + { Some(SummaryAccessibility::Accessible(Box::new(cached.summary.clone()))) } else { Some(SummaryAccessibility::Inaccessible) @@ -226,25 +234,25 @@ async fn get_summary_and_children_local( }); } - Ok( - if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? { - let summary = self.get_room_summary(current_room, children_pdus, &identifier); - if let Ok(summary) = summary { - self.roomid_spacehierarchy_cache.lock().await.insert( - current_room.clone(), - Some(CachedSpaceHierarchySummary { - summary: summary.clone(), - }), - ); - - Some(SummaryAccessibility::Accessible(Box::new(summary))) - } else { - None - } + if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? { + let summary = self + .get_room_summary(current_room, children_pdus, &identifier) + .await; + if let Ok(summary) = summary { + self.roomid_spacehierarchy_cache.lock().await.insert( + current_room.clone(), + Some(CachedSpaceHierarchySummary { + summary: summary.clone(), + }), + ); + + Ok(Some(SummaryAccessibility::Accessible(Box::new(summary)))) } else { - None - }, - ) + Ok(None) + } + } else { + Ok(None) + } } /// Gets the summary of a space using solely federation @@ -322,12 +330,15 @@ async fn get_summary_and_children_federation( ); } } - if self.is_accessible_child( - current_room, - &response.room.join_rule, - &Identifier::UserId(user_id), - &response.room.allowed_room_ids, - ) { + if self + .is_accessible_child( + current_room, + &response.room.join_rule, + &Identifier::UserId(user_id), + &response.room.allowed_room_ids, + ) + .await + { return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone())))); } @@ -358,7 +369,7 @@ async fn get_summary_and_children_client( } } - fn get_room_summary( + async fn get_room_summary( &self, current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>, identifier: &Identifier<'_>, ) -> Result<SpaceHierarchyParentSummary, Error> { @@ -367,48 +378,38 @@ fn get_room_summary( let join_rule = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| c.join_rule) - .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) - }) - .transpose()? 
- .unwrap_or(JoinRule::Invite); + .room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") + .await + .map_or(JoinRule::Invite, |c: RoomJoinRulesEventContent| c.join_rule); let allowed_room_ids = self .services .state_accessor .allowed_room_ids(join_rule.clone()); - if !self.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) { - debug!("User is not allowed to see room {room_id}"); + if !self + .is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) + .await + { + debug_info!("User is not allowed to see room {room_id}"); // This error will be caught later return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room")); } - let join_rule = join_rule.into(); - Ok(SpaceHierarchyParentSummary { canonical_alias: self .services .state_accessor .get_canonical_alias(room_id) - .unwrap_or(None), - name: self - .services - .state_accessor - .get_name(room_id) - .unwrap_or(None), + .await + .ok(), + name: self.services.state_accessor.get_name(room_id).await.ok(), num_joined_members: self .services .state_cache .room_joined_count(room_id) - .unwrap_or_default() - .unwrap_or_else(|| { - warn!("Room {room_id} has no member count"); - 0 - }) + .await + .unwrap_or(0) .try_into() .expect("user count should not be that big"), room_id: room_id.to_owned(), @@ -416,25 +417,36 @@ fn get_room_summary( .services .state_accessor .get_room_topic(room_id) - .unwrap_or(None), - world_readable: self.services.state_accessor.is_world_readable(room_id)?, - guest_can_join: self.services.state_accessor.guest_can_join(room_id)?, + .await + .ok(), + world_readable: self + .services + .state_accessor + .is_world_readable(room_id) + .await, + guest_can_join: self.services.state_accessor.guest_can_join(room_id).await, avatar_url: self .services .state_accessor - .get_avatar(room_id)? 
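+				// get_avatar returns a JsOption; a missing or undefined event means no avatar URL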
+ .get_avatar(room_id) + .await .into_option() .unwrap_or_default() .url, - join_rule, - room_type: self.services.state_accessor.get_room_type(room_id)?, + join_rule: join_rule.into(), + room_type: self + .services + .state_accessor + .get_room_type(room_id) + .await + .ok(), children_state, allowed_room_ids, }) } pub async fn get_client_hierarchy( - &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec<u64>, max_depth: u64, + &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec<ShortRoomId>, max_depth: u64, suggested_only: bool, ) -> Result<client::space::get_hierarchy::v1::Response> { let mut parents = VecDeque::new(); @@ -450,7 +462,7 @@ pub async fn get_client_hierarchy( }, )]]; - let mut results = Vec::new(); + let mut results = Vec::with_capacity(limit); while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) } { if results.len() >= limit { @@ -474,21 +486,22 @@ pub async fn get_client_hierarchy( results.push(summary_to_chunk(*summary.clone())); } else { children = children - .into_iter() - .rev() - .skip_while(|(room, _)| { - if let Ok(short) = self.services.short.get_shortroomid(room) - { - short.as_ref() != short_room_ids.get(parents.len()) - } else { - false - } - }) - .collect::<Vec<_>>() - // skip_while doesn't implement DoubleEndedIterator, which is needed for rev - .into_iter() - .rev() - .collect(); + .iter() + .rev() + .stream() + .skip_while(|(room, _)| { + self.services + .short + .get_shortroomid(room) + .map_ok(|short| Some(&short) != short_room_ids.get(parents.len())) + .unwrap_or_else(|_| false) + }) + .map(Clone::clone) + .collect::<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>() + .await + .into_iter() + .rev() + .collect(); if children.is_empty() { return Err(Error::BadRequest( @@ -528,11 +541,12 @@ pub async fn get_client_hierarchy( parents.pop_front(); parents.push_back(room); - let mut short_room_ids = vec![]; - - for room in parents { - short_room_ids.push(self.services.short.get_or_create_shortroomid(&room)?); - } + let short_room_ids: Vec<_> = parents + .iter() + .stream() + .filter_map(|room_id| async move { self.services.short.get_shortroomid(room_id).await.ok() }) + .collect() + .await; Some( PaginationToken { @@ -554,7 +568,7 @@ pub async fn get_client_hierarchy( async fn get_stripped_space_child_events( &self, room_id: &RoomId, ) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> { - let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? else { + let Ok(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id).await else { return Ok(None); }; @@ -562,10 +576,13 @@ async fn get_stripped_space_child_events( .services .state_accessor .state_full_ids(current_shortstatehash) - .await?; - let mut children_pdus = Vec::new(); + .await + .map_err(|e| err!(Database("State in space not found: {e}")))?; + + let mut children_pdus = Vec::with_capacity(state.len()); for (key, id) in state { - let (event_type, state_key) = self.services.short.get_statekey_from_short(key)?; + let (event_type, state_key) = self.services.short.get_statekey_from_short(key).await?; + if event_type != StateEventType::SpaceChild { continue; } @@ -573,15 +590,14 @@ async fn get_stripped_space_child_events( let pdu = self .services .timeline - .get_pdu(&id)? 
-			.ok_or_else(|| Error::bad_database("Event in space state not found"))?;
+			.get_pdu(&id)
+			.await
+			.map_err(|e| err!(Database("Event {id:?} in space state not found: {e:?}")))?;
-		if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get())
-			.ok()
-			.map(|c| c.via)
-			.map_or(true, |v| v.is_empty())
-		{
-			continue;
+		if let Ok(content) = pdu.get_content::<SpaceChildEventContent>() {
+			if content.via.is_empty() {
+				continue;
+			}
 		}
 		if OwnedRoomId::try_from(state_key).is_ok() {
@@ -593,20 +609,18 @@
 	}
 	/// With the given identifier, checks if a room is accessible
-	fn is_accessible_child(
+	async fn is_accessible_child(
 		&self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>,
 		allowed_room_ids: &Vec<OwnedRoomId>,
 	) -> bool {
-		// Note: unwrap_or_default for bool means false
 		match identifier {
 			Identifier::ServerName(server_name) => {
-				let room_id: &RoomId = current_room;
-
 				// Checks if ACLs allow for the server to participate
 				if self
 					.services
 					.event_handler
-					.acl_check(server_name, room_id)
+					.acl_check(server_name, current_room)
+					.await
 					.is_err()
 				{
 					return false;
@@ -617,38 +631,28 @@ fn is_accessible_child(
 					.services
 					.state_cache
 					.is_joined(user_id, current_room)
-					.unwrap_or_default()
-					|| self
-						.services
-						.state_cache
-						.is_invited(user_id, current_room)
-						.unwrap_or_default()
+					.await || self
+					.services
+					.state_cache
+					.is_invited(user_id, current_room)
+					.await
 				{
 					return true;
 				}
 			},
-		} // Takes care of join rules
-		match join_rule {
+		}
+		match &join_rule {
+			SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true,
 			SpaceRoomJoinRule::Restricted => {
 				for room in allowed_room_ids {
 					match identifier {
 						Identifier::UserId(user) => {
-							if self
-								.services
-								.state_cache
-								.is_joined(user, room)
-								.unwrap_or_default()
-							{
+							if self.services.state_cache.is_joined(user, room).await {
 								return true;
 							}
 						},
 						Identifier::ServerName(server) => {
-							if self
-								.services
-								.state_cache
-								.server_in_room(server, room)
-								.unwrap_or_default()
-							{
+							if self.services.state_cache.server_in_room(server, room).await {
 								return true;
 							}
 						},
@@ -656,7 +660,6 @@ fn is_accessible_child(
 				}
 				false
 			},
-			SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true,
 			// Invite only, Private, or Custom join rule
 			_ => false,
 		}
diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs
deleted file mode 100644
index 3c110afc633c39318b20405cf181ae9acbdf9b57..0000000000000000000000000000000000000000
--- a/src/service/rooms/state/data.rs
+++ /dev/null
@@ -1,89 +0,0 @@
-use std::{collections::HashSet, sync::Arc};
-
-use conduit::{utils, Error, Result};
-use database::{Database, Map};
-use ruma::{EventId, OwnedEventId, RoomId};
-
-use super::RoomMutexGuard;
-
-pub(super) struct Data {
-	shorteventid_shortstatehash: Arc<Map>,
-	roomid_pduleaves: Arc<Map>,
-	roomid_shortstatehash: Arc<Map>,
-}
-
-impl Data {
-	pub(super) fn new(db: &Arc<Database>) -> Self {
-		Self {
-			shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(),
-			roomid_pduleaves: db["roomid_pduleaves"].clone(),
-			roomid_shortstatehash: db["roomid_shortstatehash"].clone(),
-		}
-	}
-
-	pub(super) fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
-		self.roomid_shortstatehash
-			.get(room_id.as_bytes())?
- .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") - })?)) - }) - } - - #[inline] - pub(super) fn set_room_state( - &self, - room_id: &RoomId, - new_shortstatehash: u64, - _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - self.roomid_shortstatehash - .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; - Ok(()) - } - - pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; - Ok(()) - } - - pub(super) fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - self.roomid_pduleaves - .scan_prefix(prefix) - .map(|(_, bytes)| { - EventId::parse_arc( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) - }) - .collect() - } - - pub(super) fn set_forward_extremities( - &self, - room_id: &RoomId, - event_ids: Vec<OwnedEventId>, - _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { - self.roomid_pduleaves.remove(&key)?; - } - - for event_id in event_ids { - let mut key = prefix.clone(); - key.extend_from_slice(event_id.as_bytes()); - self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; - } - - Ok(()) - } -} diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index cb219bc038096e247d835647d8b93630ab5db83d..29ffedfce8f9698eecccd71d34cf5b6be9671d6e 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,18 +1,19 @@ -mod data; - use std::{ collections::{HashMap, HashSet}, fmt::Write, + iter::once, sync::Arc, }; use conduit::{ - utils::{calculate_hash, MutexMap, MutexMapGuard}, - warn, Error, PduEvent, Result, + err, + result::FlatOk, + utils::{calculate_hash, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt}, + warn, PduEvent, Result, }; -use data::Data; +use database::{Deserialized, Ignore, Interfix, Map}; +use futures::{future::join_all, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{ - api::client::error::ErrorKind, events::{ room::{create::RoomCreateEventContent, member::RoomMemberEventContent}, AnyStrippedStateEvent, StateEventType, TimelineEventType, @@ -26,9 +27,9 @@ use crate::{globals, rooms, Dep}; pub struct Service { + pub mutex: RoomMutexMap, services: Services, db: Data, - pub mutex: RoomMutexMap, } struct Services { @@ -41,12 +42,19 @@ struct Services { timeline: Dep<rooms::timeline::Service>, } +struct Data { + shorteventid_shortstatehash: Arc<Map>, + roomid_shortstatehash: Arc<Map>, + roomid_pduleaves: Arc<Map>, +} + type RoomMutexMap = MutexMap<OwnedRoomId, ()>; pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + mutex: RoomMutexMap::new(), services: Services { globals: args.depend::<globals::Service>("globals"), short: args.depend::<rooms::short::Service>("rooms::short"), @@ -56,12 +64,15 @@ 
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"), timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), }, - db: Data::new(args.db), - mutex: RoomMutexMap::new(), + db: Data { + shorteventid_shortstatehash: args.db["shorteventid_shortstatehash"].clone(), + roomid_shortstatehash: args.db["roomid_shortstatehash"].clone(), + roomid_pduleaves: args.db["roomid_pduleaves"].clone(), + }, })) } - fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + fn memory_usage(&self, out: &mut dyn Write) -> Result { let mutex = self.mutex.len(); writeln!(out, "state_mutex: {mutex}")?; @@ -80,48 +91,34 @@ pub async fn force_state( statediffnew: Arc<HashSet<CompressedStateEvent>>, _statediffremoved: Arc<HashSet<CompressedStateEvent>>, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - for event_id in statediffnew.iter().filter_map(|new| { + ) -> Result { + let event_ids = statediffnew.iter().stream().filter_map(|new| { self.services .state_compressor - .parse_compressed_state_event(new) - .ok() - .map(|(_, id)| id) - }) { - let Some(pdu) = self.services.timeline.get_pdu_json(&event_id)? else { - continue; - }; + .parse_compressed_state_event(*new) + .map_ok_or_else(|_| None, |(_, event_id)| Some(event_id)) + }); - let pdu: PduEvent = match serde_json::from_str( - &serde_json::to_string(&pdu).expect("CanonicalJsonObj can be serialized to JSON"), - ) { - Ok(pdu) => pdu, - Err(_) => continue, + pin_mut!(event_ids); + while let Some(event_id) = event_ids.next().await { + let Ok(pdu) = self.services.timeline.get_pdu(&event_id).await else { + continue; }; match pdu.kind { TimelineEventType::RoomMember => { - let Ok(membership_event) = serde_json::from_str::<RoomMemberEventContent>(pdu.content.get()) else { + let Some(user_id) = pdu.state_key.as_ref().map(UserId::parse).flat_ok() else { continue; }; - let Some(state_key) = pdu.state_key else { + let Ok(membership_event) = pdu.get_content::<RoomMemberEventContent>() else { continue; }; - let Ok(user_id) = UserId::parse(state_key) else { - continue; - }; - - self.services.state_cache.update_membership( - room_id, - &user_id, - membership_event, - &pdu.sender, - None, - None, - false, - )?; + self.services + .state_cache + .update_membership(room_id, &user_id, membership_event, &pdu.sender, None, None, false) + .await?; }, TimelineEventType::SpaceChild => { self.services @@ -135,10 +132,9 @@ pub async fn force_state( } } - self.services.state_cache.update_joined_count(room_id)?; + self.services.state_cache.update_joined_count(room_id).await; - self.db - .set_room_state(room_id, shortstatehash, state_lock)?; + self.set_room_state(room_id, shortstatehash, state_lock); Ok(()) } @@ -148,39 +144,45 @@ pub async fn force_state( /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. 
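+	/// Returns the shortstatehash of the stored state.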
#[tracing::instrument(skip(self, state_ids_compressed), level = "debug")] - pub fn set_event_state( + pub async fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, ) -> Result<u64> { - let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?; + const BUFSIZE: usize = size_of::<u64>(); + + let shorteventid = self + .services + .short + .get_or_create_shorteventid(event_id) + .await; - let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; + let previous_shortstatehash = self.get_room_shortstatehash(room_id).await; - let state_hash = calculate_hash( - &state_ids_compressed - .iter() - .map(|s| &s[..]) - .collect::<Vec<_>>(), - ); + let state_hash = calculate_hash(state_ids_compressed.iter().map(|s| &s[..])); let (shortstatehash, already_existed) = self .services .short - .get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await; if !already_existed { - let states_parents = previous_shortstatehash.map_or_else( - || Ok(Vec::new()), - |p| self.services.state_compressor.load_shortstatehash_info(p), - )?; + let states_parents = if let Ok(p) = previous_shortstatehash { + self.services + .state_compressor + .load_shortstatehash_info(p) + .await? + } else { + Vec::new() + }; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = state_ids_compressed - .difference(&parent_stateinfo.1) + .difference(&parent_stateinfo.full_state) .copied() .collect(); let statediffremoved: HashSet<_> = parent_stateinfo - .1 + .full_state .difference(&state_ids_compressed) .copied() .collect(); @@ -198,7 +200,9 @@ pub fn set_event_state( )?; } - self.db.set_event_state(shorteventid, shortstatehash)?; + self.db + .shorteventid_shortstatehash + .aput::<BUFSIZE, BUFSIZE, _, _>(shorteventid, shortstatehash); Ok(shortstatehash) } @@ -208,39 +212,49 @@ pub fn set_event_state( /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, new_pdu), level = "debug")] - pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { + pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { + const BUFSIZE: usize = size_of::<u64>(); + let shorteventid = self .services .short - .get_or_create_shorteventid(&new_pdu.event_id)?; + .get_or_create_shorteventid(&new_pdu.event_id) + .await; - let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; + let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id).await; - if let Some(p) = previous_shortstatehash { - self.db.set_event_state(shorteventid, p)?; + if let Ok(p) = previous_shortstatehash { + self.db + .shorteventid_shortstatehash + .aput::<BUFSIZE, BUFSIZE, _, _>(shorteventid, p); } if let Some(state_key) = &new_pdu.state_key { - let states_parents = previous_shortstatehash.map_or_else( - || Ok(Vec::new()), - #[inline] - |p| self.services.state_compressor.load_shortstatehash_info(p), - )?; + let states_parents = if let Ok(p) = previous_shortstatehash { + self.services + .state_compressor + .load_shortstatehash_info(p) + .await? 
+ } else { + Vec::new() + }; let shortstatekey = self .services .short - .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; + .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key) + .await; let new = self .services .state_compressor - .compress_state_event(shortstatekey, &new_pdu.event_id)?; + .compress_state_event(shortstatekey, &new_pdu.event_id) + .await; let replaces = states_parents .last() .map(|info| { - info.1 + info.full_state .iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) }) @@ -275,159 +289,159 @@ pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { } } - #[tracing::instrument(skip(self, invite_event), level = "debug")] - pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result<Vec<Raw<AnyStrippedStateEvent>>> { - let mut state = Vec::new(); - // Add recommended events - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = self.services.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCanonicalAlias, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = + #[tracing::instrument(skip_all, level = "debug")] + pub async fn summary_stripped(&self, event: &PduEvent) -> Vec<Raw<AnyStrippedStateEvent>> { + let cells = [ + (&StateEventType::RoomCreate, ""), + (&StateEventType::RoomJoinRules, ""), + (&StateEventType::RoomCanonicalAlias, ""), + (&StateEventType::RoomName, ""), + (&StateEventType::RoomAvatar, ""), + (&StateEventType::RoomMember, event.sender.as_str()), // Add recommended events + (&StateEventType::RoomEncryption, ""), + (&StateEventType::RoomTopic, ""), + ]; + + let fetches = cells.iter().map(|(event_type, state_key)| { self.services .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = - self.services - .state_accessor - .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = self.services.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomMember, - invite_event.sender.as_str(), - )? { - state.push(e.to_stripped_state_event()); - } + .room_state_get(&event.room_id, event_type, state_key) + }); - state.push(invite_event.to_stripped_state_event()); - Ok(state) + join_all(fetches) + .await + .into_iter() + .filter_map(Result::ok) + .map(|e| e.to_stripped_state_event()) + .chain(once(event.to_stripped_state_event())) + .collect() } /// Set the state hash to a new version, but does not update state_cache. 
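+	/// The caller must hold the room state mutex for this room.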
- #[tracing::instrument(skip(self, mutex_lock), level = "debug")] + #[tracing::instrument(skip(self, _mutex_lock), level = "debug")] pub fn set_room_state( &self, room_id: &RoomId, shortstatehash: u64, - mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - self.db.set_room_state(room_id, shortstatehash, mutex_lock) + _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex + ) { + const BUFSIZE: usize = size_of::<u64>(); + + self.db + .roomid_shortstatehash + .raw_aput::<BUFSIZE, _, _>(room_id, shortstatehash); } /// Returns the room's version. #[tracing::instrument(skip(self), level = "debug")] - pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> { - let create_event = self - .services + pub async fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> { + self.services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: RoomCreateEventContent = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "No create event found"))?; - - Ok(create_event_content.room_version) + .room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .map(|content: RoomCreateEventContent| content.room_version) + .map_err(|e| err!(Request(NotFound("No create event found: {e:?}")))) } - #[inline] - pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { - self.db.get_room_shortstatehash(room_id) + pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<u64> { + self.db + .roomid_shortstatehash + .get(room_id) + .await + .deserialized() } - pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> { - self.db.get_forward_extremities(room_id) + pub fn get_forward_extremities<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &EventId> + Send + '_ { + let prefix = (room_id, Interfix); + + self.db + .roomid_pduleaves + .keys_prefix(&prefix) + .map_ok(|(_, event_id): (Ignore, &EventId)| event_id) + .ignore_err() } - pub fn set_forward_extremities( + pub async fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec<OwnedEventId>, - state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + _state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex + ) { + let prefix = (room_id, Interfix); self.db - .set_forward_extremities(room_id, event_ids, state_lock) + .roomid_pduleaves + .keys_prefix_raw(&prefix) + .ignore_err() + .ready_for_each(|key| self.db.roomid_pduleaves.remove(key)) + .await; + + for event_id in &event_ids { + let key = (room_id, event_id); + self.db.roomid_pduleaves.put_raw(key, event_id); + } } /// This fetches auth events from the current state. #[tracing::instrument(skip(self), level = "debug")] - pub fn get_auth_events( + pub async fn get_auth_events( &self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>, content: &serde_json::value::RawValue, ) -> Result<StateMap<Arc<PduEvent>>> { - let Some(shortstatehash) = self.get_room_shortstatehash(room_id)? 
else { + let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else { return Ok(HashMap::new()); }; - let auth_events = - state_res::auth_types_for_event(kind, sender, state_key, content).expect("content is a valid JSON object"); + let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content)?; - let mut sauthevents = auth_events - .into_iter() + let mut sauthevents: HashMap<_, _> = auth_events + .iter() + .stream() .filter_map(|(event_type, state_key)| { self.services .short - .get_shortstatekey(&event_type.to_string().into(), &state_key) - .ok() - .flatten() - .map(|s| (s, (event_type, state_key))) + .get_shortstatekey(event_type, state_key) + .map_ok(move |s| (s, (event_type, state_key))) + .map(Result::ok) }) - .collect::<HashMap<_, _>>(); + .collect() + .await; let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| { + err!(Database( + "Missing shortstatehash info for {room_id:?} at {shortstatehash:?}: {e:?}" + )) + })? .pop() .expect("there is always one layer") - .1; + .full_state; - Ok(full_state - .iter() - .filter_map(|compressed| { - self.services - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - }) - .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) - .filter_map(|(k, event_id)| { - self.services - .timeline - .get_pdu(&event_id) - .ok() - .flatten() - .map(|pdu| (k, pdu)) - }) - .collect()) + let mut ret = HashMap::new(); + for compressed in full_state.iter() { + let Ok((shortstatekey, event_id)) = self + .services + .state_compressor + .parse_compressed_state_event(*compressed) + .await + else { + continue; + }; + + let Some((ty, state_key)) = sauthevents.remove(&shortstatekey) else { + continue; + }; + + let Ok(pdu) = self.services.timeline.get_pdu(&event_id).await else { + continue; + }; + + ret.insert((ty.to_owned(), state_key.to_owned()), pdu); + } + + Ok(ret) } } diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 4c85148dbb6458b22a4f2a758f9c97a7be57eda6..06cd648cf1625f0f97672d86b24e70ff88b8239f 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,10 +1,11 @@ use std::{collections::HashMap, sync::Arc}; -use conduit::{utils, Error, PduEvent, Result}; -use database::Map; +use conduit::{err, PduEvent, Result}; +use database::{Deserialized, Map}; +use futures::TryFutureExt; use ruma::{events::StateEventType, EventId, RoomId}; -use crate::{rooms, Dep}; +use crate::{rooms, rooms::short::ShortStateHash, Dep}; pub(super) struct Data { eventid_shorteventid: Arc<Map>, @@ -35,21 +36,26 @@ pub(super) fn new(args: &crate::Args<'_>) -> Self { } #[allow(unused_qualifications)] // async traits - pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> { + pub(super) async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result<HashMap<u64, Arc<EventId>>> { let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| err!(Database("Missing state IDs: {e}")))? 
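+			// the last layer of the state chain contains the fully-resolved state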
.pop() .expect("there is always one layer") - .1; + .full_state; + let mut result = HashMap::new(); let mut i: u8 = 0; for compressed in full_state.iter() { let parsed = self .services .state_compressor - .parse_compressed_state_event(compressed)?; + .parse_compressed_state_event(*compressed) + .await?; + result.insert(parsed.0, parsed.1); i = i.wrapping_add(1); @@ -57,20 +63,22 @@ pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap tokio::task::yield_now().await; } } + Ok(result) } #[allow(unused_qualifications)] // async traits pub(super) async fn state_full( - &self, shortstatehash: u64, + &self, shortstatehash: ShortStateHash, ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await? .pop() .expect("there is always one layer") - .1; + .full_state; let mut result = HashMap::new(); let mut i: u8 = 0; @@ -78,18 +86,13 @@ pub(super) async fn state_full( let (_, eventid) = self .services .state_compressor - .parse_compressed_state_event(compressed)?; - if let Some(pdu) = self.services.timeline.get_pdu(&eventid)? { - result.insert( - ( - pdu.kind.to_string().into(), - pdu.state_key - .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? - .clone(), - ), - pdu, - ); + .parse_compressed_state_event(*compressed) + .await?; + + if let Ok(pdu) = self.services.timeline.get_pdu(&eventid).await { + if let Some(state_key) = pdu.state_key.as_ref() { + result.insert((pdu.kind.to_string().into(), state_key.clone()), pdu); + } } i = i.wrapping_add(1); @@ -101,61 +104,63 @@ pub(super) async fn state_full( Ok(result) } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). #[allow(clippy::unused_self)] - pub(super) fn state_get_id( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result<Option<Arc<EventId>>> { - let Some(shortstatekey) = self + pub(super) async fn state_get_id( + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, + ) -> Result<Arc<EventId>> { + let shortstatekey = self .services .short - .get_shortstatekey(event_type, state_key)? - else { - return Ok(None); - }; + .get_shortstatekey(event_type, state_key) + .await?; + let full_state = self .services .state_compressor - .load_shortstatehash_info(shortstatehash)? + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))? .pop() .expect("there is always one layer") - .1; - Ok(full_state + .full_state; + + let compressed = full_state .iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .and_then(|compressed| { - self.services - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - .map(|(_, id)| id) - })) + .ok_or(err!(Database("No shortstatekey in compressed state")))?; + + self.services + .state_compressor + .parse_compressed_state_event(*compressed) + .map_ok(|(_, id)| id) + .map_err(|e| { + err!(Database(error!( + ?event_type, + ?state_key, + ?shortstatekey, + "Failed to parse compressed: {e:?}" + ))) + }) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). 
- pub(super) fn state_get( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result<Option<Arc<PduEvent>>> { - self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| self.services.timeline.get_pdu(&event_id)) + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn state_get( + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, + ) -> Result<Arc<PduEvent>> { + self.state_get_id(shortstatehash, event_type, state_key) + .and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await }) + .await } /// Returns the state hash for this pdu. - pub(super) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { + pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> { self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash") - }) - }) - .transpose() - }) + .get(event_id) + .and_then(|shorteventid| self.shorteventid_shortstatehash.get(&shorteventid)) + .await + .deserialized() } /// Returns the full room state. @@ -163,34 +168,33 @@ pub(super) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64 pub(super) async fn room_state_full( &self, room_id: &RoomId, ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { - self.state_full(current_shortstatehash).await - } else { - Ok(HashMap::new()) - } + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_full(shortstatehash)) + .map_err(|e| err!(Database("Missing state for {room_id:?}: {e:?}"))) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn room_state_get_id( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result<Option<Arc<EventId>>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { - self.state_get_id(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + ) -> Result<Arc<EventId>> { + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key)) + .await } - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub(super) fn room_state_get( + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub(super) async fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result<Option<Arc<PduEvent>>> { - if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? 
{ - self.state_get(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + ) -> Result<Arc<PduEvent>> { + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key)) + .await } } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 58fa31b3d427470ab4119606a5140b3f44d428d2..4958c4eaf39aa065e9c8bc93134fca2facdd176e 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -6,8 +6,13 @@ sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{err, error, pdu::PduBuilder, utils::math::usize_from_f64, warn, Error, PduEvent, Result}; -use data::Data; +use conduit::{ + err, error, + pdu::PduBuilder, + utils::{math::usize_from_f64, ReadyExt}, + Err, Error, Event, PduEvent, Result, +}; +use futures::StreamExt; use lru_cache::LruCache; use ruma::{ events::{ @@ -24,22 +29,27 @@ power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, topic::RoomTopicEventContent, }, - StateEventType, + StateEventType, TimelineEventType, }, room::RoomType, space::SpaceRoomJoinRule, - EventEncryptionAlgorithm, EventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, - UserId, + EventEncryptionAlgorithm, EventId, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, + ServerName, UserId, }; -use serde_json::value::to_raw_value; +use serde::Deserialize; -use crate::{rooms, rooms::state::RoomMutexGuard, Dep}; +use self::data::Data; +use crate::{ + rooms, + rooms::{short::ShortStateHash, state::RoomMutexGuard}, + Dep, +}; pub struct Service { services: Services, db: Data, - pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, u64), bool>>, - pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, u64), bool>>, + pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, ShortStateHash), bool>>, + pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, ShortStateHash), bool>>, } struct Services { @@ -88,108 +98,114 @@ impl Service { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> { + pub async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result<HashMap<u64, Arc<EventId>>> { self.db.state_full_ids(shortstatehash).await } - pub async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { + pub async fn state_full( + &self, shortstatehash: ShortStateHash, + ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { self.db.state_full(shortstatehash).await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub fn state_get_id( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result<Option<Arc<EventId>>> { - self.db.state_get_id(shortstatehash, event_type, state_key) + pub async fn state_get_id( + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, + ) -> Result<Arc<EventId>> { + self.db + .state_get_id(shortstatehash, event_type, state_key) + .await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). 
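+	/// Fails if there is no matching state event in that state snapshot.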
#[inline] - pub fn state_get( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result<Option<Arc<PduEvent>>> { - self.db.state_get(shortstatehash, event_type, state_key) + pub async fn state_get( + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, + ) -> Result<Arc<PduEvent>> { + self.db + .state_get(shortstatehash, event_type, state_key) + .await + } + + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub async fn state_get_content<T>( + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, + ) -> Result<T> + where + T: for<'de> Deserialize<'de> + Send, + { + self.state_get(shortstatehash, event_type, state_key) + .await + .and_then(|event| event.get_content()) } /// Get membership for given user in state - fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result<MembershipState> { - self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())? - .map_or(Ok(MembershipState::Leave), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomMemberEventContent| c.membership) - .map_err(|_| Error::bad_database("Invalid room membership event in database.")) - }) + async fn user_membership(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> MembershipState { + self.state_get_content(shortstatehash, &StateEventType::RoomMember, user_id.as_str()) + .await + .map_or(MembershipState::Leave, |c: RoomMemberEventContent| c.membership) } /// The user was a joined member at this state (potentially in the past) #[inline] - fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id) - .is_ok_and(|s| s == MembershipState::Join) - // Return sensible default, i.e. - // false + async fn user_was_joined(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool { + self.user_membership(shortstatehash, user_id).await == MembershipState::Join } /// The user was an invited or joined room member at this state (potentially /// in the past) #[inline] - fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id) - .is_ok_and(|s| s == MembershipState::Join || s == MembershipState::Invite) - // Return sensible default, i.e. false + async fn user_was_invited(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool { + let s = self.user_membership(shortstatehash, user_id).await; + s == MembershipState::Join || s == MembershipState::Invite } /// Whether a server is allowed to see an event through federation, based on /// the room's history_visibility at that event's state. #[tracing::instrument(skip(self, origin, room_id, event_id))] - pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> Result<bool> { - let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { - return Ok(true); + pub async fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> bool { + let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { + return true; }; if let Some(visibility) = self .server_visibility_cache .lock() - .unwrap() + .expect("locked") .get_mut(&(origin.to_owned(), shortstatehash)) { - return Ok(*visibility); + return *visibility; } let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? 
-			.map_or(Ok(HistoryVisibility::Shared), |s| {
-				serde_json::from_str(s.content.get())
-					.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
-					.map_err(|e| {
-						error!(
-							"Invalid history visibility event in database for room {room_id}, assuming is \"shared\": \
-							 {e}"
-						);
-						Error::bad_database("Invalid history visibility event in database.")
-					})
-			})
-			.unwrap_or(HistoryVisibility::Shared);
+			.state_get_content(shortstatehash, &StateEventType::RoomHistoryVisibility, "")
+			.await
+			.map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| {
+				c.history_visibility
+			});
-		let mut current_server_members = self
+		let current_server_members = self
 			.services
 			.state_cache
 			.room_members(room_id)
-			.filter_map(Result::ok)
-			.filter(|member| member.server_name() == origin);
+			.ready_filter(|member| member.server_name() == origin);
 		let visibility = match history_visibility {
 			HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true,
 			HistoryVisibility::Invited => {
 				// Allow if any member on the requesting server was AT LEAST invited, else deny
-				current_server_members.any(|member| self.user_was_invited(shortstatehash, &member))
+				current_server_members
+					.any(|member| self.user_was_invited(shortstatehash, member))
+					.await
 			},
 			HistoryVisibility::Joined => {
 				// Allow if any member on the requesting server was joined, else deny
-				current_server_members.any(|member| self.user_was_joined(shortstatehash, &member))
+				current_server_members
+					.any(|member| self.user_was_joined(shortstatehash, member))
+					.await
 			},
 			_ => {
 				error!("Unknown history visibility {history_visibility}");
@@ -199,56 +215,48 @@ pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_
 		self.server_visibility_cache
 			.lock()
-			.unwrap()
+			.expect("locked")
 			.insert((origin.to_owned(), shortstatehash), visibility);
-		Ok(visibility)
+		visibility
 	}
 	/// Whether a user is allowed to see an event, based on
 	/// the room's history_visibility at that event's state.
 	#[tracing::instrument(skip(self, user_id, room_id, event_id))]
-	pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
-		let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else {
-			return Ok(true);
+	pub async fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> bool {
+		let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
+			return true;
 		};
 		if let Some(visibility) = self
 			.user_visibility_cache
 			.lock()
-			.unwrap()
+			.expect("locked")
 			.get_mut(&(user_id.to_owned(), shortstatehash))
 		{
-			return Ok(*visibility);
+			return *visibility;
 		}
-		let currently_member = self.services.state_cache.is_joined(user_id, room_id)?;
+		let currently_member = self.services.state_cache.is_joined(user_id, room_id).await;
 		let history_visibility = self
-			.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
-			.map_or(Ok(HistoryVisibility::Shared), |s| {
-				serde_json::from_str(s.content.get())
-					.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
-					.map_err(|e| {
-						error!(
-							"Invalid history visibility event in database for room {room_id}, assuming is \"shared\": \
-							 {e}"
-						);
-						Error::bad_database("Invalid history visibility event in database.")
-					})
-			})
-			.unwrap_or(HistoryVisibility::Shared);
+			.state_get_content(shortstatehash, &StateEventType::RoomHistoryVisibility, "")
+			.await
+			.map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| {
+				c.history_visibility
+			});
 		let visibility = match history_visibility {
 			HistoryVisibility::WorldReadable => true,
 			HistoryVisibility::Shared => currently_member,
 			HistoryVisibility::Invited => {
 				// Allow if the user was AT LEAST invited, else deny
-				self.user_was_invited(shortstatehash, user_id)
+				self.user_was_invited(shortstatehash, user_id).await
 			},
 			HistoryVisibility::Joined => {
 				// Allow if the user was joined, else deny
-				self.user_was_joined(shortstatehash, user_id)
+				self.user_was_joined(shortstatehash, user_id).await
 			},
 			_ => {
 				error!("Unknown history visibility {history_visibility}");
@@ -258,38 +266,34 @@ pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &
 		self.user_visibility_cache
 			.lock()
-			.unwrap()
+			.expect("locked")
 			.insert((user_id.to_owned(), shortstatehash), visibility);
-		Ok(visibility)
+		visibility
 	}
 	/// Whether a user is allowed to see the room's state events, based on
 	/// membership and the room's current history_visibility.
 	#[tracing::instrument(skip(self, user_id, room_id))]
-	pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
-		let currently_member = self.services.state_cache.is_joined(user_id, room_id)?;
+	pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> bool {
+		if self.services.state_cache.is_joined(user_id, room_id).await {
+			return true;
+		}
 		let history_visibility = self
-			.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?
-			.map_or(Ok(HistoryVisibility::Shared), |s| {
-				serde_json::from_str(s.content.get())
-					.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
-					.map_err(|e| {
-						error!(
-							"Invalid history visibility event in database for room {room_id}, assuming is \"shared\": \
-							 {e}"
-						);
-						Error::bad_database("Invalid history visibility event in database.")
-					})
-			})
-			.unwrap_or(HistoryVisibility::Shared);
-		Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable)
+			.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "")
+			.await
+			.map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| {
+				c.history_visibility
+			});
+		history_visibility == HistoryVisibility::WorldReadable
 	}
 	/// Returns the state hash for this pdu.
-	pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { self.db.pdu_shortstatehash(event_id) }
+	pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> {
+		self.db.pdu_shortstatehash(event_id).await
+	}
 	/// Returns the full room state.
 	#[tracing::instrument(skip(self), level = "debug")]
@@ -300,180 +304,164 @@ pub async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEv
 	/// Returns a single PDU from `room_id` with key (`event_type`,
 	/// `state_key`).
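+	/// Fails if the room has no current state or the event is not present.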
#[tracing::instrument(skip(self), level = "debug")] - pub fn room_state_get_id( + pub async fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result<Option<Arc<EventId>>> { - self.db.room_state_get_id(room_id, event_type, state_key) + ) -> Result<Arc<EventId>> { + self.db + .room_state_get_id(room_id, event_type, state_key) + .await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). #[tracing::instrument(skip(self), level = "debug")] - pub fn room_state_get( + pub async fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result<Option<Arc<PduEvent>>> { - self.db.room_state_get(room_id, event_type, state_key) + ) -> Result<Arc<PduEvent>> { + self.db.room_state_get(room_id, event_type, state_key).await } - pub fn get_name(&self, room_id: &RoomId) -> Result<Option<String>> { - self.room_state_get(room_id, &StateEventType::RoomName, "")? - .map_or(Ok(None), |s| { - Ok(serde_json::from_str(s.content.get()).map_or_else(|_| None, |c: RoomNameEventContent| Some(c.name))) - }) + /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). + pub async fn room_state_get_content<T>( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result<T> + where + T: for<'de> Deserialize<'de> + Send, + { + self.room_state_get(room_id, event_type, state_key) + .await + .and_then(|event| event.get_content()) } - pub fn get_avatar(&self, room_id: &RoomId) -> Result<ruma::JsOption<RoomAvatarEventContent>> { - self.room_state_get(room_id, &StateEventType::RoomAvatar, "")? - .map_or(Ok(ruma::JsOption::Undefined), |s| { - serde_json::from_str(s.content.get()) - .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) - }) + pub async fn get_name(&self, room_id: &RoomId) -> Result<String> { + self.room_state_get_content(room_id, &StateEventType::RoomName, "") + .await + .map(|c: RoomNameEventContent| c.name) } - pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<RoomMemberEventContent>> { - self.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? 
- .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map_err(|_| Error::bad_database("Invalid room member event in database.")) - }) + pub async fn get_avatar(&self, room_id: &RoomId) -> JsOption<RoomAvatarEventContent> { + let content = self + .room_state_get_content(room_id, &StateEventType::RoomAvatar, "") + .await + .ok(); + + JsOption::from_option(content) } - pub fn user_can_invite( - &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard, - ) -> Result<bool> { - let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) - .expect("Event content always serializes"); - - let new_event = PduBuilder { - event_type: ruma::events::TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(target_user.into()), - redacts: None, - timestamp: None, - }; + pub async fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<RoomMemberEventContent> { + self.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str()) + .await + } - Ok(self - .services + pub async fn user_can_invite( + &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard, + ) -> bool { + self.services .timeline - .create_hash_and_sign_event(new_event, sender, room_id, state_lock) - .is_ok()) + .create_hash_and_sign_event( + PduBuilder::state(target_user.into(), &RoomMemberEventContent::new(MembershipState::Invite)), + sender, + room_id, + state_lock, + ) + .await + .is_ok() } /// Checks if guests are able to view room content without joining - pub fn is_world_readable(&self, room_id: &RoomId) -> Result<bool, Error> { - self.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| { - c.history_visibility == HistoryVisibility::WorldReadable - }) - .map_err(|e| { - error!( - "Invalid room history visibility event in database for room {room_id}, assuming not world \ - readable: {e} " - ); - Error::bad_database("Invalid room history visibility event in database.") - }) - }) + pub async fn is_world_readable(&self, room_id: &RoomId) -> bool { + self.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "") + .await + .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable) + .unwrap_or(false) } /// Checks if guests are able to join a given room - pub fn guest_can_join(&self, room_id: &RoomId) -> Result<bool, Error> { - self.room_state_get(room_id, &StateEventType::RoomGuestAccess, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) - .map_err(|_| Error::bad_database("Invalid room guest access event in database.")) - }) + pub async fn guest_can_join(&self, room_id: &RoomId) -> bool { + self.room_state_get_content(room_id, &StateEventType::RoomGuestAccess, "") + .await + .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) + .unwrap_or(false) } /// Gets the primary alias from canonical alias event - pub fn get_canonical_alias(&self, room_id: &RoomId) -> Result<Option<OwnedRoomAliasId>, Error> { - self.room_state_get(room_id, &StateEventType::RoomCanonicalAlias, "")? 
- .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomCanonicalAliasEventContent| c.alias) - .map_err(|_| Error::bad_database("Invalid canonical alias event in database.")) + pub async fn get_canonical_alias(&self, room_id: &RoomId) -> Result<OwnedRoomAliasId> { + self.room_state_get_content(room_id, &StateEventType::RoomCanonicalAlias, "") + .await + .and_then(|c: RoomCanonicalAliasEventContent| { + c.alias + .ok_or_else(|| err!(Request(NotFound("No alias found in event content.")))) }) } /// Gets the room topic - pub fn get_room_topic(&self, room_id: &RoomId) -> Result<Option<String>, Error> { - self.room_state_get(room_id, &StateEventType::RoomTopic, "")? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomTopicEventContent| Some(c.topic)) - .map_err(|e| { - error!("Invalid room topic event in database for room {room_id}: {e}"); - Error::bad_database("Invalid room topic event in database.") - }) - }) + pub async fn get_room_topic(&self, room_id: &RoomId) -> Result<String> { + self.room_state_get_content(room_id, &StateEventType::RoomTopic, "") + .await + .map(|c: RoomTopicEventContent| c.topic) } /// Checks if a given user can redact a given event /// /// If federation is true, it allows redaction events from any user of the /// same server as the original event sender - pub fn user_can_redact( + pub async fn user_can_redact( &self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool, ) -> Result<bool> { - self.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? - .map_or_else( - || { - // Falling back on m.room.create to judge power level - if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? { - Ok(pdu.sender == sender - || if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { - pdu.sender == sender - } else { - false - }) + let redacting_event = self.services.timeline.get_pdu(redacts).await; + + if redacting_event + .as_ref() + .is_ok_and(|event| event.event_type() == &TimelineEventType::RoomCreate) + { + return Err!(Request(Forbidden("Redacting m.room.create is not safe, forbidding."))); + } + + if let Ok(pl_event_content) = self + .room_state_get_content::<RoomPowerLevelsEventContent>(room_id, &StateEventType::RoomPowerLevels, "") + .await + { + let pl_event: RoomPowerLevels = pl_event_content.into(); + Ok(pl_event.user_can_redact_event_of_other(sender) + || pl_event.user_can_redact_own_event(sender) + && if let Ok(redacting_event) = redacting_event { + if federation { + redacting_event.sender.server_name() == sender.server_name() + } else { + redacting_event.sender == sender + } } else { - Err(Error::bad_database( - "No m.room.power_levels or m.room.create events in database for room", - )) - } - }, - |event| { - serde_json::from_str(event.content.get()) - .map(|content: RoomPowerLevelsEventContent| content.into()) - .map(|event: RoomPowerLevels| { - event.user_can_redact_event_of_other(sender) - || event.user_can_redact_own_event(sender) - && if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { - if federation { - pdu.sender.server_name() == sender.server_name() - } else { - pdu.sender == sender - } - } else { - false - } - }) - .map_err(|_| Error::bad_database("Invalid m.room.power_levels event in database")) - }, - ) + false + }) + } else { + // Falling back on m.room.create to judge power level + if let Ok(room_create) = self + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await + { + Ok(room_create.sender == sender + || 
redacting_event + .as_ref() + .is_ok_and(|redacting_event| redacting_event.sender == sender)) + } else { + Err(Error::bad_database( + "No m.room.power_levels or m.room.create events in database for room", + )) + } + } } /// Returns the join rule (`SpaceRoomJoinRule`) for a given room - pub fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>), Error> { - Ok(self - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| { - (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)) - }) - .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) - }) - .transpose()? - .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) + pub async fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>)> { + self.room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") + .await + .map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule))) + .or_else(|_| Ok((SpaceRoomJoinRule::Invite, vec![]))) } /// Returns an empty vec if not a restricted room pub fn allowed_room_ids(&self, join_rule: JoinRule) -> Vec<OwnedRoomId> { - let mut room_ids = vec![]; + let mut room_ids = Vec::with_capacity(1); if let JoinRule::Restricted(r) | JoinRule::KnockRestricted(r) = join_rule { for rule in r.allow { if let AllowRule::RoomMembership(RoomMembership { @@ -487,25 +475,27 @@ pub fn allowed_room_ids(&self, join_rule: JoinRule) -> Vec<OwnedRoomId> { room_ids } - pub fn get_room_type(&self, room_id: &RoomId) -> Result<Option<RoomType>> { - Ok(self - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .map(|s| { - serde_json::from_str::<RoomCreateEventContent>(s.content.get()) - .map_err(|e| err!(Database(error!("Invalid room create event in database: {e}")))) + pub async fn get_room_type(&self, room_id: &RoomId) -> Result<RoomType> { + self.room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .and_then(|content: RoomCreateEventContent| { + content + .room_type + .ok_or_else(|| err!(Request(NotFound("No type found in event content")))) }) - .transpose()? - .and_then(|e| e.room_type)) } /// Gets the room's encryption algorithm if `m.room.encryption` state event /// is found - pub fn get_room_encryption(&self, room_id: &RoomId) -> Result<Option<EventEncryptionAlgorithm>> { - self.room_state_get(room_id, &StateEventType::RoomEncryption, "")? 
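Reduced to plain booleans, the new `user_can_redact` above implements: redacting `m.room.create` is always refused; otherwise redaction is allowed if the power levels let the sender redact others' events, or the target counts as the sender's "own" event, where federation widens "own" to any event originating from the sender's server. A sketch of just that decision table (the parameter names are mine, not conduwuit's):

fn can_redact(
    redacts_create_event: bool,
    pl_redact_other: bool,
    pl_redact_own: bool,
    same_sender: bool,
    same_server: bool,
    federation: bool,
) -> Result<bool, &'static str> {
    // The m.room.create guard runs first and wins unconditionally.
    if redacts_create_event {
        return Err("Redacting m.room.create is not safe, forbidding.");
    }

    // && binds tighter than ||, matching the expression in the diff above.
    Ok(pl_redact_other || pl_redact_own && if federation { same_server } else { same_sender })
}

fn main() {
    assert!(can_redact(true, true, true, true, true, false).is_err());
    assert_eq!(can_redact(false, false, true, false, true, true), Ok(true)); // same server over federation
    assert_eq!(can_redact(false, false, true, false, true, false), Ok(false)); // not the sender locally
}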
- .map_or(Ok(None), |s| { - serde_json::from_str::<RoomEncryptionEventContent>(s.content.get()) - .map(|content| Some(content.algorithm)) - .map_err(|e| err!(Database(error!("Invalid room encryption event in database: {e}")))) - }) + pub async fn get_room_encryption(&self, room_id: &RoomId) -> Result<EventEncryptionAlgorithm> { + self.room_state_get_content(room_id, &StateEventType::RoomEncryption, "") + .await + .map(|content: RoomEncryptionEventContent| content.algorithm) + } + + pub async fn is_encrypted_room(&self, room_id: &RoomId) -> bool { + self.room_state_get(room_id, &StateEventType::RoomEncryption, "") + .await + .is_ok() } } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs deleted file mode 100644 index 19c73ea1ca546967401138378058461226bd5568..0000000000000000000000000000000000000000 --- a/src/service/rooms/state_cache/data.rs +++ /dev/null @@ -1,666 +0,0 @@ -use std::{ - collections::{HashMap, HashSet}, - sync::{Arc, RwLock}, -}; - -use conduit::{utils, Error, Result}; -use database::Map; -use itertools::Itertools; -use ruma::{ - events::{AnyStrippedStateEvent, AnySyncStateEvent}, - serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, -}; - -use crate::{appservice::RegistrationInfo, globals, users, Dep}; - -type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>; -type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>; -type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>; - -pub(super) struct Data { - pub(super) appservice_in_room_cache: AppServiceInRoomCache, - roomid_invitedcount: Arc<Map>, - roomid_inviteviaservers: Arc<Map>, - roomid_joinedcount: Arc<Map>, - roomserverids: Arc<Map>, - roomuserid_invitecount: Arc<Map>, - roomuserid_joined: Arc<Map>, - roomuserid_leftcount: Arc<Map>, - roomuseroncejoinedids: Arc<Map>, - serverroomids: Arc<Map>, - userroomid_invitestate: Arc<Map>, - userroomid_joined: Arc<Map>, - userroomid_leftstate: Arc<Map>, - services: Services, -} - -struct Services { - globals: Dep<globals::Service>, - users: Dep<users::Service>, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - appservice_in_room_cache: RwLock::new(HashMap::new()), - roomid_invitedcount: db["roomid_invitedcount"].clone(), - roomid_inviteviaservers: db["roomid_inviteviaservers"].clone(), - roomid_joinedcount: db["roomid_joinedcount"].clone(), - roomserverids: db["roomserverids"].clone(), - roomuserid_invitecount: db["roomuserid_invitecount"].clone(), - roomuserid_joined: db["roomuserid_joined"].clone(), - roomuserid_leftcount: db["roomuserid_leftcount"].clone(), - roomuseroncejoinedids: db["roomuseroncejoinedids"].clone(), - serverroomids: db["serverroomids"].clone(), - userroomid_invitestate: db["userroomid_invitestate"].clone(), - userroomid_joined: db["userroomid_joined"].clone(), - userroomid_leftstate: db["userroomid_leftstate"].clone(), - services: Services { - globals: args.depend::<globals::Service>("globals"), - users: args.depend::<users::Service>("users"), - }, - } - } - - pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - self.roomuseroncejoinedids.insert(&userroom_id, &[]) - } - - pub(super) fn mark_as_joined(&self, user_id: 
&UserId, room_id: &RoomId) -> Result<()> { - let roomid = room_id.as_bytes().to_vec(); - - let mut roomuser_id = roomid.clone(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_joined.insert(&userroom_id, &[])?; - self.roomuserid_joined.insert(&roomuser_id, &[])?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - self.roomid_inviteviaservers.remove(&roomid)?; - - Ok(()) - } - - pub(super) fn mark_as_invited( - &self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, - invite_via: Option<Vec<OwnedServerName>>, - ) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_invitestate.insert( - &userroom_id, - &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), - )?; - self.roomuserid_invitecount - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - if let Some(servers) = invite_via { - let mut prev_servers = self - .servers_invite_via(room_id) - .filter_map(Result::ok) - .collect_vec(); - #[allow(clippy::redundant_clone)] // this is a necessary clone? 
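The deleted `data.rs` builds every composite key by hand as `first ++ 0xFF ++ second` (user/room, room/user, room/server, and so on) and reads the trailing component back with `rsplit`, as the scan code below does. A self-contained sketch of that scheme, assuming string-typed IDs:

fn composite_key(a: &str, b: &str) -> Vec<u8> {
    let mut key = a.as_bytes().to_vec();
    key.push(0xFF); // 0xFF never occurs inside valid UTF-8, so it is a safe separator
    key.extend_from_slice(b.as_bytes());
    key
}

fn last_segment(key: &[u8]) -> &[u8] {
    key.rsplit(|&b| b == 0xFF)
        .next()
        .expect("rsplit always returns an element")
}

fn main() {
    let key = composite_key("@alice:example.org", "!room:example.org");
    assert_eq!(last_segment(&key), &b"!room:example.org"[..]);
}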
- prev_servers.append(servers.clone().as_mut()); - let servers = prev_servers.iter().rev().unique().rev().collect_vec(); - - let servers = servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers)?; - } - - Ok(()) - } - - pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let roomid = room_id.as_bytes().to_vec(); - - let mut roomuser_id = roomid.clone(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_leftstate.insert( - &userroom_id, - &serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(), - )?; // TODO - self.roomuserid_leftcount - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - - self.roomid_inviteviaservers.remove(&roomid)?; - - Ok(()) - } - - pub(super) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { - let mut joinedcount = 0_u64; - let mut invitedcount = 0_u64; - let mut joined_servers = HashSet::new(); - - for joined in self.room_members(room_id).filter_map(Result::ok) { - joined_servers.insert(joined.server_name().to_owned()); - joinedcount = joinedcount.saturating_add(1); - } - - for _invited in self.room_members_invited(room_id).filter_map(Result::ok) { - invitedcount = invitedcount.saturating_add(1); - } - - self.roomid_joinedcount - .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; - - self.roomid_invitedcount - .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; - - for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) { - if !joined_servers.remove(&old_joined_server) { - // Server not in room anymore - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(old_joined_server.as_bytes()); - - let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } - } - - // Now only new servers are in joined_servers anymore - for server in joined_servers { - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xFF); - roomserver_id.extend_from_slice(server.as_bytes()); - - let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xFF); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - - self.appservice_in_room_cache - .write() - .unwrap() - .remove(room_id); - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] - pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> { - let maybe = self - .appservice_in_room_cache - .read() - .unwrap() - .get(room_id) - .and_then(|map| map.get(&appservice.registration.id)) - .copied(); - - if let Some(b) = maybe { - Ok(b) - } else { - let bridge_user_id = UserId::parse_with_server_name( - appservice.registration.sender_localpart.as_str(), - 
self.services.globals.server_name(), - ) - .ok(); - - let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) - || self - .room_members(room_id) - .any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str()))); - - self.appservice_in_room_cache - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default() - .insert(appservice.registration.id.clone(), in_room); - - Ok(in_room) - } - } - - /// Makes a user forget a room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - Ok(()) - } - - /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_servers<'a>( - &'a self, room_id: &RoomId, - ) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { - ServerName::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) - })) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> { - let mut key = server.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - - self.serverroomids.get(&key).map(|o| o.is_some()) - } - - /// Returns an iterator of all rooms a server participates in (as far as we - /// know). - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn server_rooms<'a>( - &'a self, server: &ServerName, - ) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { - let mut prefix = server.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) - })) - } - - /// Returns an iterator of all joined members of a room. 
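`appservice_in_room` (both the removed form here and the rewritten one later in this patch) is a read-then-write `RwLock` cache around an expensive membership scan. A minimal sketch of the pattern; `Cache` and `get_or_compute` are illustrative names:

use std::{collections::HashMap, sync::RwLock};

struct Cache(RwLock<HashMap<String, bool>>);

impl Cache {
    fn get_or_compute(&self, key: &str, compute: impl FnOnce() -> bool) -> bool {
        // Fast path: shared read lock only.
        if let Some(&hit) = self.0.read().expect("locked").get(key) {
            return hit;
        }

        // Slow path: compute without holding any lock, then publish. Two
        // callers may race and both compute; the result is the same either way.
        let value = compute();
        self.0.write().expect("locked").insert(key.to_owned(), value);
        value
    }
}

fn main() {
    let cache = Cache(RwLock::new(HashMap::new()));
    assert!(cache.get_or_compute("!room:example.org", || true));
    assert!(cache.get_or_compute("!room:example.org", || unreachable!("served from cache")));
}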
- #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_members<'a>( - &'a self, room_id: &RoomId, - ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + Send + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) - })) - } - - /// Returns an iterator of all our local users in the room, even if they're - /// deactivated/guests - pub(super) fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> { - Box::new( - self.room_members(room_id) - .filter_map(Result::ok) - .filter(|user| self.services.globals.user_is_local(user)), - ) - } - - /// Returns an iterator of all our local joined users in a room who are - /// active (not deactivated, not guest) - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn active_local_users_in_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> { - Box::new( - self.local_users_in_room(room_id) - .filter(|user| !self.services.users.is_deactivated(user).unwrap_or(true)), - ) - } - - /// Returns the number of users which are currently in a room - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> { - self.roomid_joinedcount - .get(room_id.as_bytes())? - .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) - .transpose() - } - - /// Returns the number of users which are currently invited to a room - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { - self.roomid_invitedcount - .get(room_id.as_bytes())? - .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) - .transpose() - } - - /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_useroncejoined<'a>( - &'a self, room_id: &RoomId, - ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.roomuseroncejoinedids - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) - }), - ) - } - - /// Returns an iterator over all invited members of a room. 
- #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn room_members_invited<'a>( - &'a self, room_id: &RoomId, - ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.roomuserid_invitecount - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_invitecount - .get(&key)? - .map_or(Ok(None), |bytes| { - Ok(Some( - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?, - )) - }) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_leftcount - .get(&key)? - .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db."))) - .transpose() - } - - /// Returns an iterator over all rooms this user joined. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + '_> { - Box::new( - self.userroomid_joined - .scan_prefix(user_id.as_bytes().to_vec()) - .map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) - }), - ) - } - - /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.userroomid_invitestate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok((room_id, state)) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn invite_state( - &self, user_id: &UserId, room_id: &RoomId, - ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - - self.userroomid_invitestate - .get(&key)? 
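The invite/left counts in these tables are stored as 8-byte big-endian integers (`next_count()?.to_be_bytes()` on write, `u64_from_bytes` on read). Big-endian matters because it makes byte order agree with numeric order under the database's lexicographic key comparison. A sketch of the round trip:

fn encode_count(count: u64) -> [u8; 8] {
    count.to_be_bytes()
}

fn decode_count(bytes: &[u8]) -> Option<u64> {
    // A wrong width is the "Invalid invitecount in db" case above.
    Some(u64::from_be_bytes(bytes.try_into().ok()?))
}

fn main() {
    assert_eq!(decode_count(&encode_count(42)), Some(42));
    assert!(decode_count(&[0_u8; 4]).is_none());
    assert!(encode_count(1) < encode_count(256)); // byte order == numeric order
}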
- .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok(state) - }) - .transpose() - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn left_state( - &self, user_id: &UserId, room_id: &RoomId, - ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(room_id.as_bytes()); - - self.userroomid_leftstate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok(state) - }) - .transpose() - } - - /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - Box::new( - self.userroomid_leftstate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok((room_id, state)) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn servers_invite_via<'a>( - &'a self, room_id: &RoomId, - ) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> { - let key = room_id.as_bytes().to_vec(); - - Box::new( - self.roomid_inviteviaservers - .scan_prefix(key) - .map(|(_, servers)| { - ServerName::parse( - utils::string_from_bytes( - servers - .rsplit(|&b| b == 0xFF) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Server name in 
roomid_inviteviaservers is invalid.")) - }), - ) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub(super) fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { - let mut prev_servers = self - .servers_invite_via(room_id) - .filter_map(Result::ok) - .collect_vec(); - prev_servers.extend(servers.to_owned()); - prev_servers.sort_unstable(); - prev_servers.dedup(); - - let servers = prev_servers - .iter() - .map(|server| server.as_bytes()) - .collect_vec() - .join(&[0xFF][..]); - - self.roomid_inviteviaservers - .insert(room_id.as_bytes(), &servers)?; - - Ok(()) - } -} diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 71899ceb986a5bee71adb3e46197997afda29d3a..6e330fdc1e68f97926e179a27fc0c09a26de0ad8 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,14 +1,20 @@ -mod data; - -use std::sync::Arc; +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, RwLock}, +}; -use conduit::{err, error, warn, Error, Result}; -use data::Data; +use conduit::{ + err, is_not_empty, + result::LogErr, + utils::{stream::TryIgnore, ReadyExt, StreamTools}, + warn, Result, +}; +use database::{serialize_to_vec, Deserialized, Ignore, Interfix, Json, Map}; +use futures::{future::join4, stream::iter, Stream, StreamExt}; use itertools::Itertools; use ruma::{ events::{ direct::DirectEvent, - ignored_user_list::IgnoredUserListEvent, room::{ create::RoomCreateEventContent, member::{MembershipState, RoomMemberEventContent}, @@ -18,12 +24,13 @@ }, int, serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId, }; use crate::{account_data, appservice::RegistrationInfo, globals, rooms, users, Dep}; pub struct Service { + appservice_in_room_cache: AppServiceInRoomCache, services: Services, db: Data, } @@ -35,16 +42,49 @@ struct Services { users: Dep<users::Service>, } +struct Data { + roomid_invitedcount: Arc<Map>, + roomid_inviteviaservers: Arc<Map>, + roomid_joinedcount: Arc<Map>, + roomserverids: Arc<Map>, + roomuserid_invitecount: Arc<Map>, + roomuserid_joined: Arc<Map>, + roomuserid_leftcount: Arc<Map>, + roomuseroncejoinedids: Arc<Map>, + serverroomids: Arc<Map>, + userroomid_invitestate: Arc<Map>, + userroomid_joined: Arc<Map>, + userroomid_leftstate: Arc<Map>, +} + +type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>; +type StrippedStateEventItem = (OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>); +type SyncStateEventItem = (OwnedRoomId, Vec<Raw<AnySyncStateEvent>>); + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + appservice_in_room_cache: RwLock::new(HashMap::new()), services: Services { account_data: args.depend::<account_data::Service>("account_data"), globals: args.depend::<globals::Service>("globals"), state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), users: args.depend::<users::Service>("users"), }, - db: Data::new(&args), + db: Data { + roomid_invitedcount: args.db["roomid_invitedcount"].clone(), + roomid_inviteviaservers: args.db["roomid_inviteviaservers"].clone(), + roomid_joinedcount: args.db["roomid_joinedcount"].clone(), + roomserverids: args.db["roomserverids"].clone(), + roomuserid_invitecount: args.db["roomuserid_invitecount"].clone(), + roomuserid_joined: args.db["roomuserid_joined"].clone(), + roomuserid_leftcount: 
args.db["roomuserid_leftcount"].clone(), + roomuseroncejoinedids: args.db["roomuseroncejoinedids"].clone(), + serverroomids: args.db["serverroomids"].clone(), + userroomid_invitestate: args.db["userroomid_invitestate"].clone(), + userroomid_joined: args.db["userroomid_joined"].clone(), + userroomid_leftstate: args.db["userroomid_leftstate"].clone(), + }, })) } @@ -55,7 +95,7 @@ impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] #[allow(clippy::too_many_arguments)] - pub fn update_membership( + pub async fn update_membership( &self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, invite_via: Option<Vec<OwnedServerName>>, update_joined_count: bool, @@ -68,7 +108,7 @@ pub fn update_membership( // update #[allow(clippy::collapsible_if)] if !self.services.globals.user_is_local(user_id) { - if !self.services.users.exists(user_id)? { + if !self.services.users.exists(user_id).await { self.services.users.create(user_id, None)?; } @@ -100,17 +140,17 @@ pub fn update_membership( match &membership { MembershipState::Join => { // Check if the user never joined this room - if !self.once_joined(user_id, room_id)? { + if !self.once_joined(user_id, room_id).await { // Add the user ID to the join list then - self.db.mark_as_once_joined(user_id, room_id)?; + self.mark_as_once_joined(user_id, room_id); // Check if the room has a predecessor - if let Some(predecessor) = self + if let Ok(Some(predecessor)) = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .and_then(|create| serde_json::from_str(create.content.get()).ok()) - .and_then(|content: RoomCreateEventContent| content.predecessor) + .room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .map(|content: RoomCreateEventContent| content.predecessor) { // Copy user settings from predecessor to the current room: // - Push rules @@ -138,32 +178,27 @@ pub fn update_membership( // .ok(); // Copy old tags to new room - if let Some(tag_event) = self + if let Ok(tag_event) = self .services .account_data - .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)? - .map(|event| { - serde_json::from_str(event.get()) - .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) - }) { + .get_room(&predecessor.room_id, user_id, RoomAccountDataEventType::Tag) + .await + { self.services .account_data - .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?) + .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event) + .await .ok(); }; // Copy direct chat flag - if let Some(direct_event) = self + if let Ok(mut direct_event) = self .services .account_data - .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())? 
-                            .map(|event| {
-                                serde_json::from_str::<DirectEvent>(event.get())
-                                    .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}"))))
-                            }) {
-                            let mut direct_event = direct_event?;
+                            .get_global::<DirectEvent>(user_id, GlobalAccountDataEventType::Direct)
+                            .await
+                        {
                             let mut room_ids_updated = false;
-
                             for room_ids in direct_event.content.0.values_mut() {
                                 if room_ids.iter().any(|r| r == &predecessor.room_id) {
                                     room_ids.push(room_id.to_owned());
@@ -172,236 +207,392 @@ pub fn update_membership(
                             }

                             if room_ids_updated {
-                                self.services.account_data.update(
-                                    None,
-                                    user_id,
-                                    GlobalAccountDataEventType::Direct.to_string().into(),
-                                    &serde_json::to_value(&direct_event).expect("to json always works"),
-                                )?;
+                                self.services
+                                    .account_data
+                                    .update(
+                                        None,
+                                        user_id,
+                                        GlobalAccountDataEventType::Direct.to_string().into(),
+                                        &serde_json::to_value(&direct_event).expect("to json always works"),
+                                    )
+                                    .await?;
                             }
                         };
                     }
                 }

-                self.db.mark_as_joined(user_id, room_id)?;
+                self.mark_as_joined(user_id, room_id);
             },
             MembershipState::Invite => {
                 // We want to know if the sender is ignored by the receiver
-                let is_ignored = self
-                    .services
-                    .account_data
-                    .get(
-                        None,    // Ignored users are in global account data
-                        user_id, // Receiver
-                        GlobalAccountDataEventType::IgnoredUserList
-                            .to_string()
-                            .into(),
-                    )?
-                    .map(|event| {
-                        serde_json::from_str::<IgnoredUserListEvent>(event.get())
-                            .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}"))))
-                    })
-                    .transpose()?
-                    .map_or(false, |ignored| {
-                        ignored
-                            .content
-                            .ignored_users
-                            .iter()
-                            .any(|(user, _details)| user == sender)
-                    });
-
-                if is_ignored {
+                if self.services.users.user_is_ignored(sender, user_id).await {
                     return Ok(());
                 }

-                self.db
-                    .mark_as_invited(user_id, room_id, last_state, invite_via)?;
+                self.mark_as_invited(user_id, room_id, last_state, invite_via)
+                    .await;
             },
             MembershipState::Leave | MembershipState::Ban => {
-                self.db.mark_as_left(user_id, room_id)?;
+                self.mark_as_left(user_id, room_id);
             },
             _ => {},
         }

         if update_joined_count {
-            self.update_joined_count(room_id)?;
+            self.update_joined_count(room_id).await;
         }

         Ok(())
     }

-    #[tracing::instrument(skip(self, room_id), level = "debug")]
-    pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) }
-
-    #[tracing::instrument(skip(self, room_id, appservice), level = "debug")]
-    pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
-        self.db.appservice_in_room(room_id, appservice)
+    pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool {
+        if let Some(cached) = self
+            .appservice_in_room_cache
+            .read()
+            .expect("locked")
+            .get(room_id)
+            .and_then(|map| map.get(&appservice.registration.id))
+            .copied()
+        {
+            return cached;
+        }
+
+        let bridge_user_id = UserId::parse_with_server_name(
+            appservice.registration.sender_localpart.as_str(),
+            self.services.globals.server_name(),
+        );
+
+        let Ok(bridge_user_id) = bridge_user_id.log_err() else {
+            return false;
+        };
+
+        let in_room = self.is_joined(&bridge_user_id, room_id).await
+            || self
+                .room_members(room_id)
+                .ready_any(|user_id| appservice.users.is_match(user_id.as_str()))
+                .await;
+
+        self.appservice_in_room_cache
+            .write()
+            .expect("locked")
+            .entry(room_id.into())
+            .or_default()
+            .insert(appservice.registration.id.clone(), in_room);
+
+        in_room
     }

-    /// Direct DB function to directly mark a user as left. It is not
+    /// Direct DB function to directly mark a user as joined. It is not
     /// recommended to use this directly. You most likely should use
     /// `update_membership` instead
     #[tracing::instrument(skip(self), level = "debug")]
-    pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
-        self.db.mark_as_left(user_id, room_id)
+    pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) {
+        let userroom_id = (user_id, room_id);
+        let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id");
+
+        let roomuser_id = (room_id, user_id);
+        let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id");
+
+        self.db.userroomid_joined.insert(&userroom_id, []);
+        self.db.roomuserid_joined.insert(&roomuser_id, []);
+
+        self.db.userroomid_invitestate.remove(&userroom_id);
+        self.db.roomuserid_invitecount.remove(&roomuser_id);
+
+        self.db.userroomid_leftstate.remove(&userroom_id);
+        self.db.roomuserid_leftcount.remove(&roomuser_id);
+
+        self.db.roomid_inviteviaservers.remove(room_id);
     }

-    /// Direct DB function to directly mark a user as joined. It is not
+    /// Direct DB function to directly mark a user as left. It is not
     /// recommended to use this directly. You most likely should use
     /// `update_membership` instead
     #[tracing::instrument(skip(self), level = "debug")]
-    pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
-        self.db.mark_as_joined(user_id, room_id)
+    pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
+        let userroom_id = (user_id, room_id);
+        let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id");
+
+        let roomuser_id = (room_id, user_id);
+        let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id");
+
+        // (timo) TODO
+        let leftstate = Vec::<Raw<AnySyncStateEvent>>::new();
+        let count = self.services.globals.next_count().unwrap();
+
+        self.db
+            .userroomid_leftstate
+            .raw_put(&userroom_id, Json(leftstate));
+        self.db.roomuserid_leftcount.raw_put(&roomuser_id, count);
+
+        self.db.userroomid_joined.remove(&userroom_id);
+        self.db.roomuserid_joined.remove(&roomuser_id);
+
+        self.db.userroomid_invitestate.remove(&userroom_id);
+        self.db.roomuserid_invitecount.remove(&roomuser_id);
+
+        self.db.roomid_inviteviaservers.remove(room_id);
     }

     /// Makes a user forget a room.
     #[tracing::instrument(skip(self), level = "debug")]
-    pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) }
+    pub fn forget(&self, room_id: &RoomId, user_id: &UserId) {
+        let userroom_id = (user_id, room_id);
+        let roomuser_id = (room_id, user_id);
+
+        self.db.userroomid_leftstate.del(userroom_id);
+        self.db.roomuserid_leftcount.del(roomuser_id);
+    }

     /// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self), level = "debug")] - pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ { - self.db.room_servers(room_id) + pub fn room_servers<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &ServerName> + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomserverids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, server): (Ignore, &ServerName)| server) } #[tracing::instrument(skip(self), level = "debug")] - pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> { - self.db.server_in_room(server, room_id) + pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a RoomId) -> bool { + let key = (server, room_id); + self.db.serverroomids.qry(&key).await.is_ok() } /// Returns an iterator of all rooms a server participates in (as far as we /// know). #[tracing::instrument(skip(self), level = "debug")] - pub fn server_rooms(&self, server: &ServerName) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { - self.db.server_rooms(server) + pub fn server_rooms<'a>(&'a self, server: &'a ServerName) -> impl Stream<Item = &RoomId> + Send + 'a { + let prefix = (server, Interfix); + self.db + .serverroomids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, room_id): (Ignore, &RoomId)| room_id) } /// Returns true if server can see user by sharing at least one room. #[tracing::instrument(skip(self), level = "debug")] - pub fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result<bool> { - Ok(self - .server_rooms(server) - .filter_map(Result::ok) - .any(|room_id: OwnedRoomId| self.is_joined(user_id, &room_id).unwrap_or(false))) + pub async fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> bool { + self.server_rooms(server) + .any(|room_id| self.is_joined(user_id, room_id)) + .await } /// Returns true if user_a and user_b share at least one room. #[tracing::instrument(skip(self), level = "debug")] - pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result<bool> { + pub async fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> bool { // Minimize number of point-queries by iterating user with least nr rooms - let (a, b) = if self.rooms_joined(user_a).count() < self.rooms_joined(user_b).count() { + let (a, b) = if self.rooms_joined(user_a).count().await < self.rooms_joined(user_b).count().await { (user_a, user_b) } else { (user_b, user_a) }; - Ok(self - .rooms_joined(a) - .filter_map(Result::ok) - .any(|room_id| self.is_joined(b, &room_id).unwrap_or(false))) + self.rooms_joined(a) + .any(|room_id| self.is_joined(b, room_id)) + .await } - /// Returns an iterator over all joined members of a room. + /// Returns an iterator of all joined members of a room. 
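`user_sees_user` above first compares the two users' room counts and then scans the smaller set, so the shared-room test costs O(min(|rooms(a)|, |rooms(b)|)) point lookups instead of a full cross-product. The same idea with plain sets (a sketch, not the service API):

use std::collections::HashSet;

fn shares_room(a: &HashSet<u64>, b: &HashSet<u64>) -> bool {
    // Scan the smaller membership set, point-query the larger.
    let (small, large) = if a.len() <= b.len() { (a, b) } else { (b, a) };
    small.iter().any(|room| large.contains(room))
}

fn main() {
    let alice: HashSet<u64> = HashSet::from([1, 2, 3]);
    let bob: HashSet<u64> = HashSet::from([3, 9]);
    assert!(shares_room(&alice, &bob));
}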
#[tracing::instrument(skip(self), level = "debug")] - pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + Send + '_ { - self.db.room_members(room_id) + pub fn room_members<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_joined + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } /// Returns the number of users which are currently in a room #[tracing::instrument(skip(self), level = "debug")] - pub fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.room_joined_count(room_id) } + pub async fn room_joined_count(&self, room_id: &RoomId) -> Result<u64> { + self.db.roomid_joinedcount.get(room_id).await.deserialized() + } #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local users in the room, even if they're /// deactivated/guests - pub fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = OwnedUserId> + 'a { - self.db.local_users_in_room(room_id) + pub fn local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a { + self.room_members(room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) } #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local joined users in a room who are /// active (not deactivated, not guest) - pub fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = OwnedUserId> + 'a { - self.db.active_local_users_in_room(room_id) + pub fn active_local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a { + self.local_users_in_room(room_id) + .filter(|user| self.services.users.is_active(user)) } /// Returns the number of users which are currently invited to a room #[tracing::instrument(skip(self), level = "debug")] - pub fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.room_invited_count(room_id) } + pub async fn room_invited_count(&self, room_id: &RoomId) -> Result<u64> { + self.db + .roomid_invitedcount + .get(room_id) + .await + .deserialized() + } /// Returns an iterator over all User IDs who ever joined a room. #[tracing::instrument(skip(self), level = "debug")] - pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ { - self.db.room_useroncejoined(room_id) + pub fn room_useroncejoined<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuseroncejoinedids + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } /// Returns an iterator over all invited members of a room. 
#[tracing::instrument(skip(self), level = "debug")] - pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ { - self.db.room_members_invited(room_id) + pub fn room_members_invited<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a { + let prefix = (room_id, Interfix); + self.db + .roomuserid_invitecount + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, user_id): (Ignore, &UserId)| user_id) } #[tracing::instrument(skip(self), level = "debug")] - pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { - self.db.get_invite_count(room_id, user_id) + pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> { + let key = (room_id, user_id); + self.db + .roomuserid_invitecount + .qry(&key) + .await + .deserialized() } #[tracing::instrument(skip(self), level = "debug")] - pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { - self.db.get_left_count(room_id, user_id) + pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> { + let key = (room_id, user_id); + self.db.roomuserid_leftcount.qry(&key).await.deserialized() } /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { - self.db.rooms_joined(user_id) + pub fn rooms_joined<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = &RoomId> + Send + 'a { + self.db + .userroomid_joined + .keys_raw_prefix(user_id) + .ignore_err() + .map(|(_, room_id): (Ignore, &RoomId)| room_id) } /// Returns an iterator over all rooms a user was invited to. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_invited( - &self, user_id: &UserId, - ) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + '_ { - self.db.rooms_invited(user_id) + pub fn rooms_invited<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a { + type KeyVal<'a> = (Key<'a>, Raw<Vec<AnyStrippedStateEvent>>); + type Key<'a> = (&'a UserId, &'a RoomId); + + let prefix = (user_id, Interfix); + self.db + .userroomid_invitestate + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) + .ignore_err() } #[tracing::instrument(skip(self), level = "debug")] - pub fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { - self.db.invite_state(user_id, room_id) + pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> { + let key = (user_id, room_id); + self.db + .userroomid_invitestate + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as().map_err(Into::into)) } #[tracing::instrument(skip(self), level = "debug")] - pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { - self.db.left_state(user_id, room_id) + pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> { + let key = (user_id, room_id); + self.db + .userroomid_leftstate + .qry(&key) + .await + .deserialized() + .and_then(|val: Raw<Vec<AnyStrippedStateEvent>>| val.deserialize_as().map_err(Into::into)) } /// 
Returns an iterator over all rooms a user left. #[tracing::instrument(skip(self), level = "debug")] - pub fn rooms_left( - &self, user_id: &UserId, - ) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + '_ { - self.db.rooms_left(user_id) + pub fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = SyncStateEventItem> + Send + 'a { + type KeyVal<'a> = (Key<'a>, Raw<Vec<Raw<AnySyncStateEvent>>>); + type Key<'a> = (&'a UserId, &'a RoomId); + + let prefix = (user_id, Interfix); + self.db + .userroomid_leftstate + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, room_id), state): KeyVal<'_>| (room_id.to_owned(), state)) + .map(|(room_id, state)| Ok((room_id, state.deserialize_as()?))) + .ignore_err() } #[tracing::instrument(skip(self), level = "debug")] - pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - self.db.once_joined(user_id, room_id) + pub async fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.roomuseroncejoinedids.qry(&key).await.is_ok() } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_joined(user_id, room_id) } + pub async fn is_joined<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_joined.qry(&key).await.is_ok() + } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - self.db.is_invited(user_id, room_id) + pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_invitestate.qry(&key).await.is_ok() } #[tracing::instrument(skip(self), level = "debug")] - pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_left(user_id, room_id) } + pub async fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> bool { + let key = (user_id, room_id); + self.db.userroomid_leftstate.qry(&key).await.is_ok() + } + + pub async fn user_membership(&self, user_id: &UserId, room_id: &RoomId) -> Option<MembershipState> { + let states = join4( + self.is_joined(user_id, room_id), + self.is_left(user_id, room_id), + self.is_invited(user_id, room_id), + self.once_joined(user_id, room_id), + ) + .await; + + match states { + (true, ..) => Some(MembershipState::Join), + (_, true, ..) => Some(MembershipState::Leave), + (_, _, true, ..) 
=> Some(MembershipState::Invite), + (false, false, false, true) => Some(MembershipState::Ban), + _ => None, + } + } #[tracing::instrument(skip(self), level = "debug")] - pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ { - self.db.servers_invite_via(room_id) + pub fn servers_invite_via<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &ServerName> + Send + 'a { + type KeyVal<'a> = (Ignore, Vec<&'a ServerName>); + + self.db + .roomid_inviteviaservers + .stream_raw_prefix(room_id) + .ignore_err() + .map(|(_, servers): KeyVal<'_>| *servers.last().expect("at least one server")) } /// Gets up to three servers that are likely to be in the room in the @@ -409,44 +600,32 @@ pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Resul /// /// See <https://spec.matrix.org/v1.10/appendices/#routing> #[tracing::instrument(skip(self))] - pub fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> { + pub async fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> { let most_powerful_user_server = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? - .map(|pdu| { - serde_json::from_str(pdu.content.get()).map(|conent: RoomPowerLevelsEventContent| { - conent - .users - .iter() - .max_by_key(|(_, power)| *power) - .and_then(|x| { - if x.1 >= &int!(50) { - Some(x) - } else { - None - } - }) - .map(|(user, _power)| user.server_name().to_owned()) - }) + .room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "") + .await + .map(|content: RoomPowerLevelsEventContent| { + content + .users + .iter() + .max_by_key(|(_, power)| *power) + .and_then(|x| (x.1 >= &int!(50)).then_some(x)) + .map(|(user, _power)| user.server_name().to_owned()) }) - .transpose() - .map_err(|e| { - error!("Invalid power levels event content in database: {e}"); - Error::bad_database("Invalid power levels event content in database") - })? 
- .flatten(); + .map_err(|e| err!(Database(error!(?e, "Invalid power levels event content in database."))))?; let mut servers: Vec<OwnedServerName> = self .room_members(room_id) - .filter_map(Result::ok) .counts_by(|user| user.server_name().to_owned()) - .iter() + .await + .into_iter() .sorted_by_key(|(_, users)| *users) - .map(|(server, _)| server.to_owned()) + .map(|(server, _)| server) .rev() .take(3) - .collect_vec(); + .collect(); if let Some(server) = most_powerful_user_server { servers.insert(0, server); @@ -457,15 +636,125 @@ pub fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName> } pub fn get_appservice_in_room_cache_usage(&self) -> (usize, usize) { - let cache = self.db.appservice_in_room_cache.read().expect("locked"); + let cache = self.appservice_in_room_cache.read().expect("locked"); + (cache.len(), cache.capacity()) } pub fn clear_appservice_in_room_cache(&self) { - self.db - .appservice_in_room_cache + self.appservice_in_room_cache .write() .expect("locked") .clear(); } + + pub async fn update_joined_count(&self, room_id: &RoomId) { + let mut joinedcount = 0_u64; + let mut invitedcount = 0_u64; + let mut joined_servers = HashSet::new(); + + self.room_members(room_id) + .ready_for_each(|joined| { + joined_servers.insert(joined.server_name().to_owned()); + joinedcount = joinedcount.saturating_add(1); + }) + .await; + + invitedcount = invitedcount.saturating_add( + self.room_members_invited(room_id) + .count() + .await + .try_into() + .unwrap_or(0), + ); + + self.db.roomid_joinedcount.raw_put(room_id, joinedcount); + self.db.roomid_invitedcount.raw_put(room_id, invitedcount); + + self.room_servers(room_id) + .ready_for_each(|old_joined_server| { + if joined_servers.remove(old_joined_server) { + return; + } + + // Server not in room anymore + let roomserver_id = (room_id, old_joined_server); + let serverroom_id = (old_joined_server, room_id); + + self.db.roomserverids.del(roomserver_id); + self.db.serverroomids.del(serverroom_id); + }) + .await; + + // Now only new servers are in joined_servers anymore + for server in &joined_servers { + let roomserver_id = (room_id, server); + let serverroom_id = (server, room_id); + + self.db.roomserverids.put_raw(roomserver_id, []); + self.db.serverroomids.put_raw(serverroom_id, []); + } + + self.appservice_in_room_cache + .write() + .expect("locked") + .remove(room_id); + } + + fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) { + let key = (user_id, room_id); + self.db.roomuseroncejoinedids.put_raw(key, []); + } + + pub async fn mark_as_invited( + &self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, + invite_via: Option<Vec<OwnedServerName>>, + ) { + let roomuser_id = (room_id, user_id); + let roomuser_id = serialize_to_vec(roomuser_id).expect("failed to serialize roomuser_id"); + + let userroom_id = (user_id, room_id); + let userroom_id = serialize_to_vec(userroom_id).expect("failed to serialize userroom_id"); + + self.db + .userroomid_invitestate + .raw_put(&userroom_id, Json(last_state.unwrap_or_default())); + + self.db + .roomuserid_invitecount + .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); + + self.db.userroomid_joined.remove(&userroom_id); + self.db.roomuserid_joined.remove(&roomuser_id); + + self.db.userroomid_leftstate.remove(&userroom_id); + self.db.roomuserid_leftcount.remove(&roomuser_id); + + if let Some(servers) = invite_via.filter(is_not_empty!()) { + self.add_servers_invite_via(room_id, servers).await; + } + 
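The rewritten `servers_route_via` above implements the routing heuristic from the Matrix spec appendix: the server of the most powerful user goes first (only if that user's power level is at least 50), followed by up to three servers ranked by joined-member count. A self-contained sketch of that ranking, assuming `itertools` (which this module already imports):

use itertools::Itertools;

fn route_via(user_power: &[(&str, i64)], member_servers: &[&str]) -> Vec<String> {
    // Server of the single most powerful user, if their power level is >= 50.
    let powerful = user_power
        .iter()
        .copied()
        .max_by_key(|&(_, pl)| pl)
        .filter(|&(_, pl)| pl >= 50)
        .map(|(server, _)| server.to_owned());

    // The three servers with the most joined members, most populous first.
    let mut servers: Vec<String> = member_servers
        .iter()
        .counts()
        .into_iter()
        .sorted_by_key(|&(_, n)| n)
        .rev()
        .take(3)
        .map(|(server, _)| (*server).to_owned())
        .collect();

    if let Some(server) = powerful {
        servers.insert(0, server);
    }

    servers
}

fn main() {
    let via = route_via(&[("a.example", 100)], &["b.example", "b.example", "c.example"]);
    assert_eq!(via, ["a.example", "b.example", "c.example"]);
}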
} + + #[tracing::instrument(skip(self, servers), level = "debug")] + pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: Vec<OwnedServerName>) { + let mut servers: Vec<_> = self + .servers_invite_via(room_id) + .map(ToOwned::to_owned) + .chain(iter(servers.into_iter())) + .collect() + .await; + + servers.sort_unstable(); + servers.dedup(); + + let servers = servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); + + self.db + .roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers); + } } diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs deleted file mode 100644 index 337730019ac34b6a5fdc196bb1c5a35c053e38ab..0000000000000000000000000000000000000000 --- a/src/service/rooms/state_compressor/data.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::{collections::HashSet, mem::size_of, sync::Arc}; - -use conduit::{checked, utils, Error, Result}; -use database::{Database, Map}; - -use super::CompressedStateEvent; - -pub(super) struct StateDiff { - pub(super) parent: Option<u64>, - pub(super) added: Arc<HashSet<CompressedStateEvent>>, - pub(super) removed: Arc<HashSet<CompressedStateEvent>>, -} - -pub(super) struct Data { - shortstatehash_statediff: Arc<Map>, -} - -impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { - Self { - shortstatehash_statediff: db["shortstatehash_statediff"].clone(), - } - } - - pub(super) fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> { - let value = self - .shortstatehash_statediff - .get(&shortstatehash.to_be_bytes())? - .ok_or_else(|| Error::bad_database("State hash does not exist"))?; - let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length"); - let parent = if parent != 0 { - Some(parent) - } else { - None - }; - - let mut add_mode = true; - let mut added = HashSet::new(); - let mut removed = HashSet::new(); - - let stride = size_of::<u64>(); - let mut i = stride; - while let Some(v) = value.get(i..checked!(i + 2 * stride)?) 
{ - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i = checked!(i + stride)?; - continue; - } - if add_mode { - added.insert(v.try_into().expect("we checked the size above")); - } else { - removed.insert(v.try_into().expect("we checked the size above")); - } - i = checked!(i + 2 * stride)?; - } - - Ok(StateDiff { - parent, - added: Arc::new(added), - removed: Arc::new(removed), - }) - } - - pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) -> Result<()> { - let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); - for new in diff.added.iter() { - value.extend_from_slice(&new[..]); - } - - if !diff.removed.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in diff.removed.iter() { - value.extend_from_slice(&removed[..]); - } - } - - self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value) - } -} diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 2550774e11e3fa6b5f78b028b20ccf4b41385db9..0466fb125ccb925dc16a00273fff94cc9cf69039 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,55 +1,30 @@ -mod data; - use std::{ - collections::HashSet, + collections::{HashMap, HashSet}, fmt::Write, mem::size_of, - sync::{Arc, Mutex as StdMutex, Mutex}, + sync::{Arc, Mutex}, }; -use conduit::{checked, utils, utils::math::usize_from_f64, Result}; -use data::Data; +use arrayvec::ArrayVec; +use conduit::{ + at, checked, debug, err, expected, utils, + utils::{bytes, math::usize_from_f64}, + Result, +}; +use database::Map; use lru_cache::LruCache; use ruma::{EventId, RoomId}; -use self::data::StateDiff; -use crate::{rooms, Dep}; - -type StateInfoLruCache = Mutex< - LruCache< - u64, - Vec<( - u64, // sstatehash - Arc<HashSet<CompressedStateEvent>>, // full state - Arc<HashSet<CompressedStateEvent>>, // added - Arc<HashSet<CompressedStateEvent>>, // removed - )>, - >, ->; - -type ShortStateInfoResult = Result< - Vec<( - u64, // sstatehash - Arc<HashSet<CompressedStateEvent>>, // full state - Arc<HashSet<CompressedStateEvent>>, // added - Arc<HashSet<CompressedStateEvent>>, // removed - )>, ->; - -type ParentStatesVec = Vec<( - u64, // sstatehash - Arc<HashSet<CompressedStateEvent>>, // full state - Arc<HashSet<CompressedStateEvent>>, // added - Arc<HashSet<CompressedStateEvent>>, // removed -)>; - -type HashSetCompressStateEvent = Result<(u64, Arc<HashSet<CompressedStateEvent>>, Arc<HashSet<CompressedStateEvent>>)>; -pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()]; +use crate::{ + rooms, + rooms::short::{ShortId, ShortStateHash, ShortStateKey}, + Dep, +}; pub struct Service { + pub stateinfo_cache: Mutex<StateInfoLruCache>, db: Data, services: Services, - pub stateinfo_cache: StateInfoLruCache, } struct Services { @@ -57,23 +32,77 @@ struct Services { state: Dep<rooms::state::Service>, } +struct Data { + shortstatehash_statediff: Arc<Map>, +} + +#[derive(Clone)] +struct StateDiff { + parent: Option<ShortStateHash>, + added: Arc<CompressedState>, + removed: Arc<CompressedState>, +} + +#[derive(Clone, Default)] +pub struct ShortStateInfo { + pub shortstatehash: ShortStateHash, + pub full_state: Arc<CompressedState>, + pub added: Arc<CompressedState>, + pub removed: Arc<CompressedState>, +} + +#[derive(Clone, Default)] +pub struct HashSetCompressStateEvent { + pub shortstatehash: ShortStateHash, + pub added: Arc<CompressedState>, + pub removed: Arc<CompressedState>, +} + +type 
StateInfoLruCache = LruCache<ShortStateHash, ShortStateInfoVec>; +type ShortStateInfoVec = Vec<ShortStateInfo>; +type ParentStatesVec = Vec<ShortStateInfo>; + +pub(crate) type CompressedState = HashSet<CompressedStateEvent>; +pub(crate) type CompressedStateEvent = [u8; 2 * size_of::<ShortId>()]; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let config = &args.server.config; let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier; Ok(Arc::new(Self { - db: Data::new(args.db), + stateinfo_cache: LruCache::new(usize_from_f64(cache_capacity)?).into(), + db: Data { + shortstatehash_statediff: args.db["shortstatehash_statediff"].clone(), + }, services: Services { short: args.depend::<rooms::short::Service>("rooms::short"), state: args.depend::<rooms::state::Service>("rooms::state"), }, - stateinfo_cache: StdMutex::new(LruCache::new(usize_from_f64(cache_capacity)?)), })) } - fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { - let stateinfo_cache = self.stateinfo_cache.lock().expect("locked").len(); - writeln!(out, "stateinfo_cache: {stateinfo_cache}")?; + fn memory_usage(&self, out: &mut dyn Write) -> Result { + let (cache_len, ents) = { + let cache = self.stateinfo_cache.lock().expect("locked"); + let ents = cache + .iter() + .map(at!(1)) + .flat_map(|vec| vec.iter()) + .fold(HashMap::new(), |mut ents, ssi| { + ents.insert(Arc::as_ptr(&ssi.added), compressed_state_size(&ssi.added)); + ents.insert(Arc::as_ptr(&ssi.removed), compressed_state_size(&ssi.removed)); + ents.insert(Arc::as_ptr(&ssi.full_state), compressed_state_size(&ssi.full_state)); + ents + }); + + (cache.len(), ents) + }; + + let ents_len = ents.len(); + let bytes = ents.values().copied().fold(0_usize, usize::saturating_add); + + let bytes = bytes::pretty(bytes); + writeln!(out, "stateinfo_cache: {cache_len} {ents_len} ({bytes})")?; Ok(()) } @@ -86,12 +115,11 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. 
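// ---- editorial note (not part of the patch) ----
// A minimal sketch of the packing behind CompressedStateEvent, assuming
// ShortId = u64 (which matches the u64 handling in the code this patch
// removes): the big-endian shortstatekey followed by the big-endian
// shorteventid, i.e. 2 * size_of::<u64>() = 16 bytes, exactly what
// compress_state_event() below writes with an ArrayVec.
fn pack_compressed(shortstatekey: u64, shorteventid: u64) -> [u8; 16] {
    let mut out = [0_u8; 16];
    out[..8].copy_from_slice(&shortstatekey.to_be_bytes());
    out[8..].copy_from_slice(&shorteventid.to_be_bytes());
    out
}
fn unpack_compressed(event: [u8; 16]) -> (u64, u64) {
    let key = u64::from_be_bytes(event[..8].try_into().expect("8 bytes"));
    let id = u64::from_be_bytes(event[8..].try_into().expect("8 bytes"));
    (key, id)
}
// Round trip: unpack_compressed(pack_compressed(k, e)) == (k, e).
// ---- end note ----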
- #[tracing::instrument(skip(self), level = "debug")] - pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { + pub async fn load_shortstatehash_info(&self, shortstatehash: ShortStateHash) -> Result<ShortStateInfoVec> { if let Some(r) = self .stateinfo_cache .lock() - .unwrap() + .expect("locked") .get_mut(&shortstatehash) { return Ok(r.clone()); @@ -101,56 +129,82 @@ pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoRes parent, added, removed, - } = self.db.get_statediff(shortstatehash)?; + } = self.get_statediff(shortstatehash).await?; - if let Some(parent) = parent { - let mut response = self.load_shortstatehash_info(parent)?; - let mut state = (*response.last().unwrap().1).clone(); + let response = if let Some(parent) = parent { + let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?; + let mut state = (*response.last().expect("at least one response").full_state).clone(); state.extend(added.iter().copied()); let removed = (*removed).clone(); for r in &removed { state.remove(r); } - response.push((shortstatehash, Arc::new(state), added, Arc::new(removed))); - - self.stateinfo_cache - .lock() - .unwrap() - .insert(shortstatehash, response.clone()); + response.push(ShortStateInfo { + shortstatehash, + full_state: Arc::new(state), + added, + removed: Arc::new(removed), + }); - Ok(response) + response } else { - let response = vec![(shortstatehash, added.clone(), added, removed)]; - self.stateinfo_cache - .lock() - .unwrap() - .insert(shortstatehash, response.clone()); - Ok(response) - } - } + vec![ShortStateInfo { + shortstatehash, + full_state: added.clone(), + added, + removed, + }] + }; - pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result<CompressedStateEvent> { - let mut v = shortstatekey.to_be_bytes().to_vec(); - v.extend_from_slice( - &self - .services - .short - .get_or_create_shorteventid(event_id)? 
- .to_be_bytes(), + debug!( + ?parent, + ?shortstatehash, + vec_len = %response.len(), + "cache update" ); - Ok(v.try_into().expect("we checked the size above")) + + self.stateinfo_cache + .lock() + .expect("locked") + .insert(shortstatehash, response.clone()); + + Ok(response) + } + + pub async fn compress_state_event(&self, shortstatekey: ShortStateKey, event_id: &EventId) -> CompressedStateEvent { + const SIZE: usize = size_of::<CompressedStateEvent>(); + + let shorteventid = self + .services + .short + .get_or_create_shorteventid(event_id) + .await; + + let mut v = ArrayVec::<u8, SIZE>::new(); + v.extend(shortstatekey.to_be_bytes()); + v.extend(shorteventid.to_be_bytes()); + v.as_ref() + .try_into() + .expect("failed to create CompressedStateEvent") } /// Returns shortstatekey, event id #[inline] - pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc<EventId>)> { - Ok(( - utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]).expect("bytes have right length"), - self.services.short.get_eventid_from_short( - utils::u64_from_bytes(&compressed_event[size_of::<u64>()..]).expect("bytes have right length"), - )?, - )) + pub async fn parse_compressed_state_event( + &self, compressed_event: CompressedStateEvent, + ) -> Result<(ShortStateKey, Arc<EventId>)> { + use utils::u64_from_u8; + + let shortstatekey = u64_from_u8(&compressed_event[0..size_of::<ShortStateKey>()]); + let shorteventid = u64_from_u8(&compressed_event[size_of::<ShortStateKey>()..]); + let event_id = self + .services + .short + .get_eventid_from_short(shorteventid) + .await?; + + Ok((shortstatekey, event_id)) } /// Creates a new shortstatehash that often is just a diff to an already @@ -171,12 +225,11 @@ pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEve /// for this layer /// * `parent_states` - A stack with info on shortstatehash, full state, /// added diff and removed diff for each parent layer - #[tracing::instrument(skip(self, statediffnew, statediffremoved, diff_to_sibling, parent_states), level = "debug")] pub fn save_state_from_diff( - &self, shortstatehash: u64, statediffnew: Arc<HashSet<CompressedStateEvent>>, + &self, shortstatehash: ShortStateHash, statediffnew: Arc<HashSet<CompressedStateEvent>>, statediffremoved: Arc<HashSet<CompressedStateEvent>>, diff_to_sibling: usize, mut parent_states: ParentStatesVec, - ) -> Result<()> { + ) -> Result { let statediffnew_len = statediffnew.len(); let statediffremoved_len = statediffremoved.len(); let diffsum = checked!(statediffnew_len + statediffremoved_len)?; @@ -186,8 +239,8 @@ pub fn save_state_from_diff( // Too many layers, we have to go deeper let parent = parent_states.pop().expect("parent must have a state"); - let mut parent_new = (*parent.2).clone(); - let mut parent_removed = (*parent.3).clone(); + let mut parent_new = (*parent.added).clone(); + let mut parent_removed = (*parent.removed).clone(); for removed in statediffremoved.iter() { if !parent_new.remove(removed) { @@ -220,14 +273,14 @@ pub fn save_state_from_diff( if parent_states.is_empty() { // There is no parent layer, create a new state - self.db.save_statediff( + self.save_statediff( shortstatehash, &StateDiff { parent: None, added: statediffnew, removed: statediffremoved, }, - )?; + ); return Ok(()); }; @@ -237,14 +290,14 @@ pub fn save_state_from_diff( // 2.
We replace a layer above let parent = parent_states.pop().expect("parent must have a state"); - let parent_2_len = parent.2.len(); - let parent_3_len = parent.3.len(); - let parent_diff = checked!(parent_2_len + parent_3_len)?; + let parent_added_len = parent.added.len(); + let parent_removed_len = parent.removed.len(); + let parent_diff = checked!(parent_added_len + parent_removed_len)?; if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? { // Diff too big, we replace above layer(s) - let mut parent_new = (*parent.2).clone(); - let mut parent_removed = (*parent.3).clone(); + let mut parent_new = (*parent.added).clone(); + let mut parent_removed = (*parent.removed).clone(); for removed in statediffremoved.iter() { if !parent_new.remove(removed) { @@ -273,14 +326,14 @@ pub fn save_state_from_diff( )?; } else { // Diff small enough, we add diff as layer on top of parent - self.db.save_statediff( + self.save_statediff( shortstatehash, &StateDiff { - parent: Some(parent.0), + parent: Some(parent.shortstatehash), added: statediffnew, removed: statediffremoved, }, - )?; + ); } Ok(()) @@ -288,38 +341,46 @@ pub fn save_state_from_diff( /// Returns the new shortstatehash, and the state diff from the previous /// room state - pub fn save_state( + #[tracing::instrument(skip(self, new_state_ids_compressed), level = "debug")] + pub async fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, - ) -> HashSetCompressStateEvent { - let previous_shortstatehash = self.services.state.get_room_shortstatehash(room_id)?; + ) -> Result<HashSetCompressStateEvent> { + let previous_shortstatehash = self + .services + .state + .get_room_shortstatehash(room_id) + .await + .ok(); - let state_hash = utils::calculate_hash( - &new_state_ids_compressed - .iter() - .map(|bytes| &bytes[..]) - .collect::<Vec<_>>(), - ); + let state_hash = utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..])); let (new_shortstatehash, already_existed) = self .services .short - .get_or_create_shortstatehash(&state_hash)?; + .get_or_create_shortstatehash(&state_hash) + .await; if Some(new_shortstatehash) == previous_shortstatehash { - return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new()))); + return Ok(HashSetCompressStateEvent { + shortstatehash: new_shortstatehash, + ..Default::default() + }); } - let states_parents = - previous_shortstatehash.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + let states_parents = if let Some(p) = previous_shortstatehash { + self.load_shortstatehash_info(p).await.unwrap_or_default() + } else { + ShortStateInfoVec::new() + }; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = new_state_ids_compressed - .difference(&parent_stateinfo.1) + .difference(&parent_stateinfo.full_state) .copied() .collect(); let statediffremoved: HashSet<_> = parent_stateinfo - .1 + .full_state .difference(&new_state_ids_compressed) .copied() .collect(); @@ -339,6 +400,91 @@ pub fn save_state( )?; }; - Ok((new_shortstatehash, statediffnew, statediffremoved)) + Ok(HashSetCompressStateEvent { + shortstatehash: new_shortstatehash, + added: statediffnew, + removed: statediffremoved, + }) + } + + #[tracing::instrument(skip(self), level = "debug", name = "get")] + async fn get_statediff(&self, shortstatehash: ShortStateHash) -> Result<StateDiff> { + const BUFSIZE: usize = size_of::<ShortStateHash>(); + const STRIDE: usize 
= size_of::<ShortStateHash>(); + + let value = self + .db + .shortstatehash_statediff + .aqry::<BUFSIZE, _>(&shortstatehash) + .await + .map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?; + + let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]) + .ok() + .take_if(|parent| *parent != 0); + + debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride"); + let num_values = value.len() / STRIDE; + + let mut add_mode = true; + let mut added = HashSet::with_capacity(num_values); + let mut removed = HashSet::with_capacity(num_values); + + let mut i = STRIDE; + while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i = expected!(i + STRIDE); + continue; + } + if add_mode { + added.insert(v.try_into()?); + } else { + removed.insert(v.try_into()?); + } + i = expected!(i + 2 * STRIDE); + } + + added.shrink_to_fit(); + removed.shrink_to_fit(); + Ok(StateDiff { + parent, + added: Arc::new(added), + removed: Arc::new(removed), + }) } + + fn save_statediff(&self, shortstatehash: ShortStateHash, diff: &StateDiff) { + let mut value = Vec::<u8>::with_capacity( + 2_usize + .saturating_add(diff.added.len()) + .saturating_add(diff.removed.len()), + ); + + let parent = diff.parent.unwrap_or(0_u64); + value.extend_from_slice(&parent.to_be_bytes()); + + for new in diff.added.iter() { + value.extend_from_slice(&new[..]); + } + + if !diff.removed.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in diff.removed.iter() { + value.extend_from_slice(&removed[..]); + } + } + + self.db + .shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value); + } +} + +#[inline] +fn compressed_state_size(compressed_state: &CompressedState) -> usize { + compressed_state + .len() + .checked_mul(size_of::<CompressedStateEvent>()) + .expect("CompressedState size overflow") } diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs deleted file mode 100644 index fb279a007b064af171f696d8a93dd5933fa2d04a..0000000000000000000000000000000000000000 --- a/src/service/rooms/threads/data.rs +++ /dev/null @@ -1,98 +0,0 @@ -use std::{mem::size_of, sync::Arc}; - -use conduit::{checked, utils, Error, PduEvent, Result}; -use database::Map; -use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; - -use crate::{rooms, Dep}; - -type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>; - -pub(super) struct Data { - threadid_userids: Arc<Map>, - services: Services, -} - -struct Services { - short: Dep<rooms::short::Service>, - timeline: Dep<rooms::timeline::Service>, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - threadid_userids: db["threadid_userids"].clone(), - services: Services { - short: args.depend::<rooms::short::Service>("rooms::short"), - timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), - }, - } - } - - pub(super) fn threads_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, - ) -> PduEventIterResult<'a> { - let prefix = self - .services - .short - .get_shortroomid(room_id)? 
- .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut current = prefix.clone(); - current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes()); - - Ok(Box::new( - self.threadid_userids - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pduid, _users)| { - let count = utils::u64_from_bytes(&pduid[(size_of::<u64>())..]) - .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; - let mut pdu = self - .services - .timeline - .get_pdu_from_id(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((count, pdu)) - }), - )) - } - - pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { - let users = participants - .iter() - .map(|user| user.as_bytes()) - .collect::<Vec<_>>() - .join(&[0xFF][..]); - - self.threadid_userids.insert(root_id, &users)?; - - Ok(()) - } - - pub(super) fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> { - if let Some(users) = self.threadid_userids.get(root_id)? { - Ok(Some( - users - .split(|b| *b == 0xFF) - .map(|bytes| { - UserId::parse( - utils::string_from_bytes(bytes) - .map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?, - ) - .map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) - }) - .filter_map(Result::ok) - .collect(), - )) - } else { - Ok(None) - } - } -} diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index ae51cd0f98533584e42be6288bfc6b3ae6015e0d..5821f2795ae19c4f6b3b601e80935245ef7551f6 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -1,34 +1,44 @@ -mod data; - use std::{collections::BTreeMap, sync::Arc}; -use conduit::{Error, PduEvent, Result}; -use data::Data; +use conduit::{ + err, + utils::{stream::TryIgnore, ReadyExt}, + PduCount, PduEvent, PduId, RawPduId, Result, +}; +use database::{Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{ - api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, - events::relation::BundledThread, - uint, CanonicalJsonValue, EventId, RoomId, UserId, + api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue, + EventId, OwnedUserId, RoomId, UserId, }; use serde_json::json; -use crate::{rooms, Dep}; +use crate::{rooms, rooms::short::ShortRoomId, Dep}; pub struct Service { - services: Services, db: Data, + services: Services, } struct Services { + short: Dep<rooms::short::Service>, timeline: Dep<rooms::timeline::Service>, } +pub(super) struct Data { + threadid_userids: Arc<Map>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + db: Data { + threadid_userids: args.db["threadid_userids"].clone(), + }, services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), }, - db: Data::new(&args), })) } @@ -36,30 +46,27 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - pub fn threads_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, - ) -> Result<impl Iterator<Item = Result<(u64, PduEvent)>> + 'a> { - self.db.threads_until(user_id, room_id, until, include) - } - - pub fn add_to_thread(&self, root_event_id: &EventId, pdu: 
&PduEvent) -> Result<()> { + pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { let root_id = self .services .timeline - .get_pdu_id(root_event_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?; + .get_pdu_id(root_event_id) + .await + .map_err(|e| err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}"))))?; let root_pdu = self .services .timeline - .get_pdu_from_id(&root_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; + .get_pdu_from_id(&root_id) + .await + .map_err(|e| err!(Request(InvalidParam("Thread root not found: {e:?}"))))?; let mut root_pdu_json = self .services .timeline - .get_pdu_json_from_id(&root_id)? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; + .get_pdu_json_from_id(&root_id) + .await + .map_err(|e| err!(Request(InvalidParam("Thread root pdu not found: {e:?}"))))?; if let CanonicalJsonValue::Object(unsigned) = root_pdu_json .entry("unsigned".to_owned()) @@ -103,17 +110,66 @@ pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<( self.services .timeline - .replace_pdu(&root_id, &root_pdu_json, &root_pdu)?; + .replace_pdu(&root_id, &root_pdu_json, &root_pdu) + .await?; } let mut users = Vec::new(); - if let Some(userids) = self.db.get_participants(&root_id)? { + if let Ok(userids) = self.get_participants(&root_id).await { users.extend_from_slice(&userids); } else { users.push(root_pdu.sender); } users.push(pdu.sender.clone()); - self.db.update_participants(&root_id, &users) + self.update_participants(&root_id, &users) + } + + pub async fn threads_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, shorteventid: PduCount, _inc: &'a IncludeThreads, + ) -> Result<impl Stream<Item = (PduCount, PduEvent)> + Send + 'a> { + let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?; + + let current: RawPduId = PduId { + shortroomid, + shorteventid: shorteventid.saturating_sub(1), + } + .into(); + + let stream = self + .db + .threadid_userids + .rev_raw_keys_from(¤t) + .ignore_err() + .map(RawPduId::from) + .ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes()) + .filter_map(move |pdu_id| async move { + let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; + let pdu_id: PduId = pdu_id.into(); + + if pdu.sender != user_id { + pdu.remove_transaction_id().ok(); + } + + Some((pdu_id.shorteventid, pdu)) + }); + + Ok(stream) + } + + pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result { + let users = participants + .iter() + .map(|user| user.as_bytes()) + .collect::<Vec<_>>() + .join(&[0xFF][..]); + + self.db.threadid_userids.insert(root_id, &users); + + Ok(()) + } + + pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result<Vec<OwnedUserId>> { + self.db.threadid_userids.get(root_id).await.deserialized() } } diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 2f0c8f25878c8676b53cbefb8261a3fd4452be7c..22a6c1d0d9ca67b95783511c82dd3bb9c928262b 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,14 +1,23 @@ use std::{ + borrow::Borrow, collections::{hash_map, HashMap}, - mem::size_of, - sync::{Arc, Mutex}, + sync::Arc, }; -use conduit::{checked, error, utils, Error, PduCount, PduEvent, Result}; -use database::{Database, Map}; -use 
ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use conduit::{ + at, err, + result::{LogErr, NotFound}, + utils, + utils::{future::TryExtExt, stream::TryIgnore, ReadyExt}, + Err, PduCount, PduEvent, Result, +}; +use database::{Database, Deserialized, Json, KeyVal, Map}; +use futures::{Stream, StreamExt}; +use ruma::{api::Direction, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use tokio::sync::Mutex; -use crate::{rooms, Dep}; +use super::{PduId, RawPduId}; +use crate::{rooms, rooms::short::ShortRoomId, Dep}; pub(super) struct Data { eventid_outlierpdu: Arc<Map>, @@ -25,8 +34,7 @@ struct Services { short: Dep<rooms::short::Service>, } -type PdusIterItem = Result<(PduCount, PduEvent)>; -type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>; +pub type PdusIterItem = (PduCount, PduEvent); type LastTimelineCountCache = Mutex<HashMap<OwnedRoomId, PduCount>>; impl Data { @@ -46,297 +54,242 @@ pub(super) fn new(args: &crate::Args<'_>) -> Self { } } - pub(super) fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> { + pub(super) async fn last_timeline_count(&self, sender_user: Option<&UserId>, room_id: &RoomId) -> Result<PduCount> { match self .lasttimelinecount_cache .lock() - .expect("locked") - .entry(room_id.to_owned()) + .await + .entry(room_id.into()) { - hash_map::Entry::Vacant(v) => { - if let Some(last_count) = self - .pdus_until(sender_user, room_id, PduCount::max())? - .find_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) { - Ok(*v.insert(last_count.0)) - } else { - Ok(PduCount::Normal(0)) - } - }, hash_map::Entry::Occupied(o) => Ok(*o.get()), + hash_map::Entry::Vacant(v) => Ok(self + .pdus_rev(sender_user, room_id, PduCount::max()) + .await? + .next() + .await + .map(at!(0)) + .filter(|&count| matches!(count, PduCount::Normal(_))) + .map_or_else(PduCount::max, |count| *v.insert(count))), } } /// Returns the `count` of this pdu's id. - pub(super) fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pdu_id| pdu_count(&pdu_id)) - .transpose() + pub(super) async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> { + self.get_pdu_id(event_id) + .await + .map(|pdu_id| pdu_id.pdu_count()) } /// Returns the json of a pdu. - pub(super) fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { - self.get_non_outlier_pdu_json(event_id)?.map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - }, - |x| Ok(Some(x)), - ) + pub(super) async fn get_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> { + if let Ok(pdu) = self.get_non_outlier_pdu_json(event_id).await { + return Ok(pdu); + } + + self.eventid_outlierpdu.get(event_id).await.deserialized() } /// Returns the json of a pdu. - pub(super) fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? 
- .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() + pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.get(&pduid).await.deserialized() } /// Returns the pdu's id. #[inline] - pub(super) fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<database::Handle<'_>>> { - self.eventid_pduid.get(event_id.as_bytes()) + pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> { + self.eventid_pduid + .get(event_id) + .await + .map(|handle| RawPduId::from(&*handle)) } /// Returns the pdu directly from `eventid_pduid` only. - pub(super) fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() + pub(super) async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<PduEvent> { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.get(&pduid).await.deserialized() + } + + /// Like get_non_outlier_pdu(), but without the expense of fetching and + /// parsing the PduEvent + pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result { + let pduid = self.get_pdu_id(event_id).await?; + + self.pduid_pdu.get(&pduid).await.map(|_| ()) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub(super) fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> { - if let Some(pdu) = self - .get_non_outlier_pdu(event_id)? - .map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) - .transpose() - }, - |x| Ok(Some(x)), - )? - .map(Arc::new) - { - Ok(Some(pdu)) - } else { - Ok(None) + pub(super) async fn get_pdu(&self, event_id: &EventId) -> Result<Arc<PduEvent>> { + self.get_pdu_owned(event_id).await.map(Arc::new) + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub(super) async fn get_pdu_owned(&self, event_id: &EventId) -> Result<PduEvent> { + if let Ok(pdu) = self.get_non_outlier_pdu(event_id).await { + return Ok(pdu); } + + self.eventid_outlierpdu.get(event_id).await.deserialized() + } + + /// Like get_non_outlier_pdu(), but without the expense of fetching and + /// parsing the PduEvent + pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result { + self.eventid_outlierpdu.get(event_id).await.map(|_| ()) + } + + /// Like get_pdu(), but without the expense of fetching and parsing the data + pub(super) async fn pdu_exists(&self, event_id: &EventId) -> bool { + let non_outlier = self.non_outlier_pdu_exists(event_id).is_ok(); + let outlier = self.outlier_pdu_exists(event_id).is_ok(); + + //TODO: parallelize + non_outlier.await || outlier.await } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. 
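// ---- editorial note (not part of the patch) ----
// The `//TODO: parallelize` in pdu_exists() above could plausibly be
// resolved with a concurrent join; this is a hedged sketch, not the
// project's chosen fix. The trade-off: the current form short-circuits
// once the non-outlier lookup succeeds, while join! always drives both
// lookups to completion.
use futures::join;

pub(super) async fn pdu_exists_sketch(&self, event_id: &EventId) -> bool {
    // run both existence checks concurrently instead of sequentially
    let (non_outlier, outlier) = join!(
        self.non_outlier_pdu_exists(event_id),
        self.outlier_pdu_exists(event_id),
    );
    non_outlier.is_ok() || outlier.is_ok()
}
// ---- end note ----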
- pub(super) fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + pub(super) async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> { + self.pduid_pdu.get(pdu_id).await.deserialized() } /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. - pub(super) fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> { + self.pduid_pdu.get(pdu_id).await.deserialized() } - pub(super) fn append_pdu( - &self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64, - ) -> Result<()> { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; + pub(super) async fn append_pdu( + &self, pdu_id: &RawPduId, pdu: &PduEvent, json: &CanonicalJsonObject, count: PduCount, + ) { + debug_assert!(matches!(count, PduCount::Normal(_)), "PduCount not Normal"); + self.pduid_pdu.raw_put(pdu_id, Json(json)); self.lasttimelinecount_cache .lock() - .expect("locked") - .insert(pdu.room_id.clone(), PduCount::Normal(count)); - - self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; + .await + .insert(pdu.room_id.clone(), count); - Ok(()) + self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id); + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes()); } - pub(super) fn prepend_backfill_pdu( - &self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject, - ) -> Result<()> { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; - - self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(event_id.as_bytes())?; - - Ok(()) + pub(super) fn prepend_backfill_pdu(&self, pdu_id: &RawPduId, event_id: &EventId, json: &CanonicalJsonObject) { + self.pduid_pdu.raw_put(pdu_id, Json(json)); + self.eventid_pduid.insert(event_id, pdu_id); + self.eventid_outlierpdu.remove(event_id); } /// Removes a pdu and creates a new one with the same id. - pub(super) fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"), - )?; - } else { - return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist.")); + pub(super) async fn replace_pdu( + &self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, _pdu: &PduEvent, + ) -> Result { + if self.pduid_pdu.get(pdu_id).await.is_not_found() { + return Err!(Request(NotFound("PDU does not exist."))); } + self.pduid_pdu.raw_put(pdu_id, Json(pdu_json)); + Ok(()) } /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. 
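// ---- editorial note (not part of the patch) ----
// Why the raw key streams below yield (reverse-)chronological order: a
// RawPduId is assumed here to be the big-endian shortroomid followed by
// the big-endian event count, matching the parsing in the removed
// pdu_count() helper. Big-endian encoding makes numeric order equal
// bytewise key order, so walking keys forward within the room prefix
// visits PDUs oldest-to-newest and the rev_ stream visits them
// newest-to-oldest.
fn demo_key_order() {
    let key = |room: u64, count: u64| {
        let mut k = [0_u8; 16];
        k[..8].copy_from_slice(&room.to_be_bytes());
        k[8..].copy_from_slice(&count.to_be_bytes());
        k
    };
    // within one room prefix, a later count sorts after an earlier one
    assert!(key(5, 9) < key(5, 10));
    // ready_take_while(|(key, _)| key.starts_with(&prefix)) stops at the
    // first key that belongs to a different room
    assert!(!key(6, 0).starts_with(&5_u64.to_be_bytes()));
}
// ---- end note ----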
- pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result<PdusIterator<'_>> { - let (prefix, current) = self.count_to_id(room_id, until, 1, true)?; - - let user_id = user_id.to_owned(); - - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::<PduEvent>(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) + pub(super) async fn pdus_rev<'a>( + &'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: PduCount, + ) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> { + let current = self + .count_to_id(room_id, until, Direction::Backward) + .await?; + let prefix = current.shortroomid(); + let stream = self + .pduid_pdu + .rev_raw_stream_from(¤t) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .map(move |item| Self::each_pdu(item, user_id)); + + Ok(stream) } - pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result<PdusIterator<'_>> { - let (prefix, current) = self.count_to_id(room_id, from, 1, false)?; - - let user_id = user_id.to_owned(); - - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::<PduEvent>(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) + pub(super) async fn pdus<'a>( + &'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: PduCount, + ) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> { + let current = self.count_to_id(room_id, from, Direction::Forward).await?; + let prefix = current.shortroomid(); + let stream = self + .pduid_pdu + .raw_stream_from(¤t) + .ignore_err() + .ready_take_while(move |(key, _)| key.starts_with(&prefix)) + .map(move |item| Self::each_pdu(item, user_id)); + + Ok(stream) + } + + fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: Option<&UserId>) -> PdusIterItem { + let pdu_id: RawPduId = pdu_id.into(); + + let mut pdu = + serde_json::from_slice::<PduEvent>(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON"); + + if Some(pdu.sender.borrow()) != user_id { + pdu.remove_transaction_id().log_err().ok(); + } + + pdu.add_age().log_err().ok(); + + (pdu_id.pdu_count(), pdu) } pub(super) fn increment_notification_counts( &self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>, - ) -> Result<()> { - let mut notifies_batch = Vec::new(); - let mut highlights_batch = Vec::new(); + ) { + let _cork = self.db.cork(); + for user in notifies { let mut userroom_id = user.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - notifies_batch.push(userroom_id); + increment(&self.userroomid_notificationcount, &userroom_id); } + for user in highlights { let mut userroom_id = user.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - highlights_batch.push(userroom_id); + increment(&self.userroomid_highlightcount, &userroom_id); } - - self.userroomid_notificationcount - .increment_batch(notifies_batch.iter().map(Vec::as_slice))?; - 
self.userroomid_highlightcount - .increment_batch(highlights_batch.iter().map(Vec::as_slice))?; - Ok(()) } - pub(super) fn count_to_id( - &self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, - ) -> Result<(Vec<u8>, Vec<u8>)> { - let prefix = self + async fn count_to_id(&self, room_id: &RoomId, shorteventid: PduCount, dir: Direction) -> Result<RawPduId> { + let shortroomid: ShortRoomId = self .services .short - .get_shortroomid(room_id)? - .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? - .to_be_bytes() - .to_vec(); - let mut pdu_id = prefix.clone(); + .get_shortroomid(room_id) + .await + .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?; + // +1 so we don't send the base event - let count_raw = match count { - PduCount::Normal(x) => { - if subtract { - x.saturating_sub(offset) - } else { - x.saturating_add(offset) - } - }, - PduCount::Backfilled(x) => { - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - let num = u64::MAX.saturating_sub(x); - if subtract { - num.saturating_sub(offset) - } else { - num.saturating_add(offset) - } - }, + let pdu_id = PduId { + shortroomid, + shorteventid: shorteventid.saturating_inc(dir), }; - pdu_id.extend_from_slice(&count_raw.to_be_bytes()); - Ok((prefix, pdu_id)) + Ok(pdu_id.into()) } } -/// Returns the `count` of this pdu's id. -pub(super) fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> { - let stride = size_of::<u64>(); - let pdu_id_len = pdu_id.len(); - let last_u64 = utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - stride)?..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; - let second_last_u64 = - utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - 2 * stride)?..checked!(pdu_id_len - stride)?]); - - if matches!(second_last_u64, Ok(0)) { - Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))) - } else { - Ok(PduCount::Normal(last_u64)) - } +//TODO: this is an ABA +fn increment(db: &Arc<Map>, key: &[u8]) { + let old = db.get_blocking(key); + let new = utils::increment(old.ok().as_deref()); + db.insert(key, new); } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 4f2352f81ceb7d9916716acc7981c912abfbc17a..59fc8e93001a3584a3099ce9885f9cf959199d92 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,21 +1,24 @@ mod data; use std::{ + cmp, collections::{BTreeMap, HashSet}, fmt::Write, + iter::once, sync::Arc, }; use conduit::{ - debug, error, info, + debug, err, error, implement, info, pdu::{EventHash, PduBuilder, PduCount, PduEvent}, utils, - utils::{MutexMap, MutexMapGuard}, - validated, warn, Error, Result, Server, + utils::{stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt}, + validated, warn, Err, Error, Result, Server, }; -use itertools::Itertools; +pub use conduit::{PduId, RawPduId}; +use futures::{future, future::ready, Future, FutureExt, Stream, StreamExt, TryStreamExt}; use ruma::{ - api::{client::error::ErrorKind, federation}, + api::federation, canonical_json::to_canonical_value, events::{ push_rules::PushRulesEvent, @@ -29,19 +32,21 @@ GlobalAccountDataEventType, StateEventType, TimelineEventType, }, push::{Action, Ruleset, Tweak}, - serde::Base64, state_res::{self, Event, RoomVersion}, uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, RoomVersionId, ServerName, UserId, }; use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use 
tokio::sync::RwLock; use self::data::Data; +pub use self::data::PdusIterItem; use crate::{ - account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, - rooms::state_compressor::CompressedStateEvent, sending, server_keys, Dep, + account_data, admin, appservice, + appservice::NamespaceRegex, + globals, pusher, rooms, + rooms::{short::ShortRoomId, state_compressor::CompressedStateEvent}, + sending, server_keys, users, Dep, }; // Update Relationships @@ -88,6 +93,7 @@ struct Services { sending: Dep<sending::Service>, server_keys: Dep<server_keys::Service>, user: Dep<rooms::user::Service>, + users: Dep<users::Service>, pusher: Dep<pusher::Service>, threads: Dep<rooms::threads::Service>, search: Dep<rooms::search::Service>, @@ -117,6 +123,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { sending: args.depend::<sending::Service>("sending"), server_keys: args.depend::<server_keys::Service>("server_keys"), user: args.depend::<rooms::user::Service>("rooms::user"), + users: args.depend::<users::Service>("users"), pusher: args.depend::<pusher::Service>("pusher"), threads: args.depend::<rooms::threads::Service>("rooms::threads"), search: args.depend::<rooms::search::Service>("rooms::search"), @@ -129,6 +136,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { } fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + /* let lasttimelinecount_cache = self .db .lasttimelinecount_cache @@ -136,6 +144,7 @@ fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { .expect("locked") .len(); writeln!(out, "lasttimelinecount_cache: {lasttimelinecount_cache}")?; + */ let mutex_insert = self.mutex_insert.len(); writeln!(out, "insert_mutex: {mutex_insert}")?; @@ -144,11 +153,13 @@ fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { } fn clear_cache(&self) { + /* self.db .lasttimelinecount_cache .lock() .expect("locked") .clear(); + */ } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } @@ -156,28 +167,32 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { #[tracing::instrument(skip(self), level = "debug")] - pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { - self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? + pub async fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Arc<PduEvent>> { + self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id) + .await? .next() - .map(|o| o.map(|(_, p)| Arc::new(p))) - .transpose() + .await + .map(|(_, p)| Arc::new(p)) + .ok_or_else(|| err!(Request(NotFound("No PDU found in room")))) } #[tracing::instrument(skip(self), level = "debug")] - pub fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { - self.all_pdus(user_id!("@placeholder:conduwuit.placeholder"), room_id)? - .last() - .map(|o| o.map(|(_, p)| Arc::new(p))) - .transpose() + pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result<Arc<PduEvent>> { + self.pdus_rev(None, room_id, None) + .await? 
+ .next() + .await + .map(|(_, p)| Arc::new(p)) + .ok_or_else(|| err!(Request(NotFound("No PDU found in room")))) } #[tracing::instrument(skip(self), level = "debug")] - pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> { - self.db.last_timeline_count(sender_user, room_id) + pub async fn last_timeline_count(&self, sender_user: Option<&UserId>, room_id: &RoomId) -> Result<PduCount> { + self.db.last_timeline_count(sender_user, room_id).await } /// Returns the `count` of this pdu's id. - pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> { self.db.get_pdu_count(event_id) } + pub async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> { self.db.get_pdu_count(event_id).await } // TODO Is this the same as the function above? /* @@ -203,49 +218,59 @@ pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> { */ /// Returns the json of a pdu. - pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { - self.db.get_pdu_json(event_id) + pub async fn get_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> { + self.db.get_pdu_json(event_id).await } /// Returns the json of a pdu. #[inline] - pub fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { - self.db.get_non_outlier_pdu_json(event_id) + pub async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> { + self.db.get_non_outlier_pdu_json(event_id).await } /// Returns the pdu's id. #[inline] - pub fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<database::Handle<'_>>> { - self.db.get_pdu_id(event_id) - } + pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> { self.db.get_pdu_id(event_id).await } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. #[inline] - pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { - self.db.get_non_outlier_pdu(event_id) + pub async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<PduEvent> { + self.db.get_non_outlier_pdu(event_id).await } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> { self.db.get_pdu(event_id) } + pub async fn get_pdu(&self, event_id: &EventId) -> Result<Arc<PduEvent>> { self.db.get_pdu(event_id).await } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub async fn get_pdu_owned(&self, event_id: &EventId) -> Result<PduEvent> { self.db.get_pdu_owned(event_id).await } + + /// Checks if pdu exists + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn pdu_exists<'a>(&'a self, event_id: &'a EventId) -> impl Future<Output = bool> + Send + 'a { + self.db.pdu_exists(event_id) + } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> { self.db.get_pdu_from_id(pdu_id) } + pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> { self.db.get_pdu_from_id(pdu_id).await } /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. 
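// ---- editorial note (not part of the patch) ----
// The timeline getters above change shape from Result<Option<T>> to
// Result<T>: "missing" is folded into the error instead of a second
// layer of Option. A hedged sketch of what that means at a call site,
// assuming a `services` handle with the usual rooms.timeline path:
//
//     // old: if let Some(pdu) = services.rooms.timeline.get_pdu(&event_id)? { .. }
//     // new: propagate not-found as an error,
//     let pdu = services.rooms.timeline.get_pdu(&event_id).await?;
//     // or opt back into Option where absence is expected:
//     let maybe_pdu = services.rooms.timeline.get_pdu(&event_id).await.ok();
//
// The data layer's replace_pdu() earlier in this patch leans on the same
// shape via the NotFound helper imported from conduit::result:
// `.is_not_found()` tests whether a Result failed specifically with a
// not-found error.
// ---- end note ----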
- pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> { - self.db.get_pdu_json_from_id(pdu_id) + pub async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> { + self.db.get_pdu_json_from_id(pdu_id).await } /// Removes a pdu and creates a new one with the same id. #[tracing::instrument(skip(self), level = "debug")] - pub fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { - self.db.replace_pdu(pdu_id, pdu_json, pdu) + pub async fn replace_pdu(&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { + self.db.replace_pdu(pdu_id, pdu_json, pdu).await } /// Creates a new persisted data unit and adds it to a room. @@ -261,15 +286,16 @@ pub async fn append_pdu( mut pdu_json: CanonicalJsonObject, leaves: Vec<OwnedEventId>, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<Vec<u8>> { + ) -> Result<RawPduId> { // Coalesce database writes for the remainder of this scope. let _cork = self.db.db.cork_and_flush(); let shortroomid = self .services .short - .get_shortroomid(&pdu.room_id)? - .expect("room exists"); + .get_shortroomid(&pdu.room_id) + .await + .map_err(|_| err!(Database("Room does not exist")))?; // Make unsigned fields correct. This is not properly documented in the spec, // but state events need to have previous content in the unsigned field, so @@ -279,17 +305,17 @@ pub async fn append_pdu( .entry("unsigned".to_owned()) .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) { - if let Some(shortstatehash) = self + if let Ok(shortstatehash) = self .services .state_accessor .pdu_shortstatehash(&pdu.event_id) - .unwrap() + .await { - if let Some(prev_state) = self + if let Ok(prev_state) = self .services .state_accessor .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) - .unwrap() + .await { unsigned.insert( "prev_content".to_owned(), @@ -318,10 +344,12 @@ pub async fn append_pdu( // We must keep track of all events that have been referenced. self.services .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + .mark_as_referenced(&pdu.room_id, &pdu.prev_events); + self.services .state - .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; + .set_forward_extremities(&pdu.room_id, leaves, state_lock) + .await; let insert_lock = self.mutex_insert.lock(&pdu.room_id).await; @@ -330,17 +358,20 @@ pub async fn append_pdu( // appending fails self.services .read_receipt - .private_read_set(&pdu.room_id, &pdu.sender, count1)?; + .private_read_set(&pdu.room_id, &pdu.sender, count1); self.services .user - .reset_notification_counts(&pdu.sender, &pdu.room_id)?; + .reset_notification_counts(&pdu.sender, &pdu.room_id); - let count2 = self.services.globals.next_count()?; - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&count2.to_be_bytes()); + let count2 = PduCount::Normal(self.services.globals.next_count().unwrap()); + let pdu_id: RawPduId = PduId { + shortroomid, + shorteventid: count2, + } + .into(); // Insert pdu - self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2)?; + self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await; drop(insert_lock); @@ -348,12 +379,9 @@ pub async fn append_pdu( let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? 
- .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? + .room_state_get_content(&pdu.room_id, &StateEventType::RoomPowerLevels, "") + .await + .map_err(|_| err!(Database("invalid m.room.power_levels event"))) .unwrap_or_default(); let sync_pdu = pdu.to_sync_room_event(); @@ -361,18 +389,20 @@ pub async fn append_pdu( let mut notifies = Vec::new(); let mut highlights = Vec::new(); - let mut push_target = self + let mut push_target: HashSet<_> = self .services .state_cache .active_local_users_in_room(&pdu.room_id) - .collect_vec(); + .map(ToOwned::to_owned) + .collect() + .await; if pdu.kind == TimelineEventType::RoomMember { if let Some(state_key) = &pdu.state_key { - let target_user_id = UserId::parse(state_key.clone()).expect("This state_key was previously validated"); + let target_user_id = UserId::parse(state_key.clone())?; - if !push_target.contains(&target_user_id) { - push_target.push(target_user_id); + if self.services.users.is_active_local(&target_user_id).await { + push_target.insert(target_user_id); } } } @@ -386,23 +416,18 @@ pub async fn append_pdu( let rules_for_user = self .services .account_data - .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into())? - .map(|event| { - serde_json::from_str::<PushRulesEvent>(event.get()).map_err(|e| { - warn!("Invalid push rules event in db for user ID {user}: {e}"); - Error::bad_database("Invalid push rules event in db.") - }) - }) - .transpose()? - .map_or_else(|| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); + .get_global(user, GlobalAccountDataEventType::PushRules) + .await + .map_or_else(|_| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); let mut highlight = false; let mut notify = false; - for action in - self.services - .pusher - .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? + for action in self + .services + .pusher + .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id) + .await? { match action { Action::Notify => notify = true, @@ -421,49 +446,49 @@ pub async fn append_pdu( highlights.push(user.clone()); } - for push_key in self.services.pusher.get_pushkeys(user) { - self.services - .sending - .send_pdu_push(&pdu_id, user, push_key?)?; - } + self.services + .pusher + .get_pushkeys(user) + .ready_for_each(|push_key| { + self.services + .sending + .send_pdu_push(&pdu_id, user, push_key.to_owned()) + .expect("TODO: replace with future"); + }) + .await; } self.db - .increment_notification_counts(&pdu.room_id, notifies, highlights)?; + .increment_notification_counts(&pdu.room_id, notifies, highlights); match pdu.kind { TimelineEventType::RoomRedaction => { use RoomVersionId::*; - let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?; match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { - if self.services.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { - self.redact_pdu(redact_id, pdu, shortroomid)?; + if self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? 
+ { + self.redact_pdu(redact_id, pdu, shortroomid).await?; + } + } }, _ => { - let content = - serde_json::from_str::<RoomRedactionEventContent>(pdu.content.get()).map_err(|e| { - warn!("Invalid content in redaction pdu: {e}"); - Error::bad_database("Invalid content in redaction pdu") - })?; - + let content: RoomRedactionEventContent = pdu.get_content()?; if let Some(redact_id) = &content.redacts { - if self.services.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { - self.redact_pdu(redact_id, pdu, shortroomid)?; + if self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? + { + self.redact_pdu(redact_id, pdu, shortroomid).await?; } } }, @@ -485,40 +510,32 @@ pub async fn append_pdu( let target_user_id = UserId::parse(state_key.clone()) .expect("This state_key was previously validated"); - let content = serde_json::from_str::<RoomMemberEventContent>(pdu.content.get()).map_err(|e| { - error!("Invalid room member event content in pdu: {e}"); - Error::bad_database("Invalid room member event content in pdu.") - })?; - + let content: RoomMemberEventContent = pdu.get_content()?; let invite_state = match content.membership { - MembershipState::Invite => { - let state = self.services.state.calculate_invite_state(pdu)?; - Some(state) - }, + MembershipState::Invite => self.services.state.summary_stripped(pdu).await.into(), _ => None, }; // Update our membership info, we do this here in case a user is invited // and immediately leaves we need the DB to record the invite event for auth - self.services.state_cache.update_membership( - &pdu.room_id, - &target_user_id, - content, - &pdu.sender, - invite_state, - None, - true, - )?; + self.services + .state_cache + .update_membership( + &pdu.room_id, + &target_user_id, + content, + &pdu.sender, + invite_state, + None, + true, + ) + .await?; } }, TimelineEventType::RoomMessage => { - let content = serde_json::from_str::<ExtractBody>(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - + let content: ExtractBody = pdu.get_content()?; if let Some(body) = content.body { - self.services - .search - .index_pdu(shortroomid, &pdu_id, &body)?; + self.services.search.index_pdu(shortroomid, &pdu_id, &body); if self.services.admin.is_admin_command(pdu, &body).await { self.services @@ -530,29 +547,32 @@ pub async fn append_pdu( _ => {}, } - if let Ok(content) = serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) { - if let Some(related_pducount) = self.get_pdu_count(&content.relates_to.event_id)? { + if let Ok(content) = pdu.get_content::<ExtractRelatesToEventId>() { + if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; + .add_relation(count2, related_pducount); } } - if let Ok(content) = serde_json::from_str::<ExtractRelatesTo>(pdu.content.get()) { + if let Ok(content) = pdu.get_content::<ExtractRelatesTo>() { match content.relates_to { Relation::Reply { in_reply_to, } => { // We need to do it again here, because replies don't have // event_id as a top level field - if let Some(related_pducount) = self.get_pdu_count(&in_reply_to.event_id)?
{ + if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; + .add_relation(count2, related_pducount); } }, Relation::Thread(thread) => { - self.services.threads.add_to_thread(&thread.event_id, pdu)?; + self.services + .threads + .add_to_thread(&thread.event_id, pdu) + .await?; }, _ => {}, // TODO: Aggregate other types } @@ -562,11 +582,12 @@ pub async fn append_pdu( if self .services .state_cache - .appservice_in_room(&pdu.room_id, appservice)? + .appservice_in_room(&pdu.room_id, appservice) + .await { self.services .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; continue; } @@ -582,7 +603,7 @@ pub async fn append_pdu( if state_key_uid == appservice_uid { self.services .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; continue; } } @@ -596,28 +617,27 @@ pub async fn append_pdu( .as_ref() .map_or(false, |state_key| users.is_match(state_key)) }; - let matching_aliases = |aliases: &NamespaceRegex| { + let matching_aliases = |aliases: NamespaceRegex| { self.services .alias .local_aliases_for_room(&pdu.room_id) - .filter_map(Result::ok) - .any(|room_alias| aliases.is_match(room_alias.as_str())) + .ready_any(move |room_alias| aliases.is_match(room_alias.as_str())) }; - if matching_aliases(&appservice.aliases) + if matching_aliases(appservice.aliases.clone()).await || appservice.rooms.is_match(pdu.room_id.as_str()) || matching_users(&appservice.users) { self.services .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; } } Ok(pdu_id) } - pub fn create_hash_and_sign_event( + pub async fn create_hash_and_sign_event( &self, pdu_builder: PduBuilder, sender: &UserId, @@ -636,52 +656,60 @@ pub fn create_hash_and_sign_event( let prev_events: Vec<_> = self .services .state - .get_forward_extremities(room_id)? 
- .into_iter() + .get_forward_extremities(room_id) .take(20) - .collect(); + .map(Arc::from) + .collect() + .await; // If there was no create event yet, assume we are creating a room - let room_version_id = self.services.state.get_room_version(room_id).or_else(|_| { - if event_type == TimelineEventType::RoomCreate { - let content = serde_json::from_str::<RoomCreateEventContent>(content.get()) - .expect("Invalid content in RoomCreate pdu."); - Ok(content.room_version) - } else { - Err(Error::InconsistentRoomState( - "non-create event for room of unknown version", - room_id.to_owned(), - )) - } - })?; + let room_version_id = self + .services + .state + .get_room_version(room_id) + .await + .or_else(|_| { + if event_type == TimelineEventType::RoomCreate { + let content: RoomCreateEventContent = serde_json::from_str(content.get())?; + Ok(content.room_version) + } else { + Err(Error::InconsistentRoomState( + "non-create event for room of unknown version", + room_id.to_owned(), + )) + } + })?; let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); - let auth_events = - self.services - .state - .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; + let auth_events = self + .services + .state + .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content) + .await?; // Our depth is the maximum depth of prev_events + 1 let depth = prev_events .iter() - .filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) + .stream() + .map(Ok) + .and_then(|event_id| self.get_pdu(event_id)) + .and_then(|pdu| future::ok(pdu.depth)) + .ignore_err() + .ready_fold(uint!(0), cmp::max) + .await .saturating_add(uint!(1)); let mut unsigned = unsigned.unwrap_or_default(); if let Some(state_key) = &state_key { - if let Some(prev_pdu) = - self.services - .state_accessor - .room_state_get(room_id, &event_type.to_string().into(), state_key)? 
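
The depth calculation above moves from a synchronous `filter_map`/`max` over `prev_events` to a stream fold; `ignore_err` and `ready_fold` are conduwuit extension traits, but the same shape in plain `futures` looks like this (an async `fetch_depth` stands in for `get_pdu`):

    use futures::{stream, StreamExt};

    async fn fetch_depth(event_id: &str) -> Result<u64, ()> {
        // stand-in for an async PDU lookup that may fail
        Ok(event_id.len() as u64)
    }

    async fn next_depth(prev_events: &[&str]) -> u64 {
        stream::iter(prev_events)
            .filter_map(|id| async move { fetch_depth(id).await.ok() }) // skip unknown events
            .fold(0_u64, |acc, depth| async move { acc.max(depth) })    // running maximum
            .await
            .saturating_add(1) // our depth is max(prev_events) + 1
    }
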
+ if let Ok(prev_pdu) = self + .services + .state_accessor + .room_state_get(room_id, &event_type.to_string().into(), state_key) + .await { - unsigned.insert( - "prev_content".to_owned(), - serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), - ); + unsigned.insert("prev_content".to_owned(), prev_pdu.get_content_as_value()); unsigned.insert( "prev_sender".to_owned(), serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), @@ -727,19 +755,22 @@ pub fn create_hash_and_sign_event( signatures: None, }; + let auth_fetch = |k: &StateEventType, s: &str| { + let key = (k.clone(), s.to_owned()); + ready(auth_events.get(&key)) + }; + let auth_check = state_res::auth_check( &room_version, &pdu, - None::<PduEvent>, // TODO: third_party_invite - |k, s| auth_events.get(&(k.clone(), s.to_owned())), + None, // TODO: third_party_invite + auth_fetch, ) - .map_err(|e| { - error!("Auth check failed: {:?}", e); - Error::BadRequest(ErrorKind::forbidden(), "Auth check failed.") - })?; + .await + .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; if !auth_check { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Event is not authorized.")); + return Err!(Request(Forbidden("Event is not authorized."))); } // Hash and sign @@ -762,21 +793,15 @@ pub fn create_hash_and_sign_event( to_canonical_value(self.services.globals.server_name()).expect("server name is a valid CanonicalJsonValue"), ); - match ruma::signatures::hash_and_sign_event( - self.services.globals.server_name().as_str(), - self.services.globals.keypair(), - &mut pdu_json, - &room_version_id, - ) { - Ok(()) => {}, - Err(e) => { - return match e { - ruma::signatures::Error::PduSize => { - Err(Error::BadRequest(ErrorKind::TooLarge, "Message is too long")) - }, - _ => Err(Error::BadRequest(ErrorKind::Unknown, "Signing event failed")), - } - }, + if let Err(e) = self + .services + .server_keys + .hash_and_sign_event(&mut pdu_json, &room_version_id) + { + return match e { + Error::Signatures(ruma::signatures::Error::PduSize) => Err!(Request(TooLarge("Message is too long"))), + _ => Err!(Request(Unknown("Signing event failed"))), + }; } // Generate event id @@ -795,7 +820,8 @@ pub fn create_hash_and_sign_event( let _shorteventid = self .services .short - .get_or_create_shorteventid(&pdu.event_id)?; + .get_or_create_shorteventid(&pdu.event_id) + .await; Ok((pdu, pdu_json)) } @@ -811,108 +837,40 @@ pub async fn build_and_append_pdu( room_id: &RoomId, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<Arc<EventId>> { - let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; - if let Some(admin_room) = self.services.admin.get_admin_room()? 
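
`state_res::auth_check` is now awaited and takes an async auth-event fetcher; since the auth events were already loaded into a map, the `auth_fetch` closure above only wraps a map lookup in `ready(..)`. The same shape in isolation, with plain `String` pairs standing in for `(StateEventType, state_key)` keys:

    use std::collections::HashMap;

    use futures::future::{ready, Ready};

    fn auth_fetch<'a>(
        auth_events: &'a HashMap<(String, String), u64>,
    ) -> impl Fn(&str, &str) -> Ready<Option<&'a u64>> + 'a {
        // no real awaiting happens: `ready` satisfies the async interface
        // for data that is already in memory
        move |k: &str, s: &str| ready(auth_events.get(&(k.to_owned(), s.to_owned())))
    }
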
{ - if admin_room == room_id { - match pdu.event_type() { - TimelineEventType::RoomEncryption => { - warn!("Encryption is not allowed in the admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Encryption is not allowed in the admins room", - )); - }, - TimelineEventType::RoomMember => { - let target = pdu - .state_key() - .filter(|v| v.starts_with('@')) - .unwrap_or(sender.as_str()); - let server_user = &self.services.globals.server_user.to_string(); - - let content = serde_json::from_str::<RoomMemberEventContent>(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu"))?; - - if content.membership == MembershipState::Leave { - if target == server_user { - warn!("Server user cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server user cannot leave from admins room.", - )); - } - - let count = self - .services - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .filter(|m| self.services.globals.server_is_ours(m.server_name()) && m != target) - .count(); - if count < 2 { - warn!("Last admin cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Last admin cannot leave from admins room.", - )); - } - } - - if content.membership == MembershipState::Ban && pdu.state_key().is_some() { - if target == server_user { - warn!("Server user cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Server user cannot be banned in admins room.", - )); - } + let (pdu, pdu_json) = self + .create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock) + .await?; - let count = self - .services - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .filter(|m| self.services.globals.server_is_ours(m.server_name()) && m != target) - .count(); - if count < 2 { - warn!("Last admin cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Last admin cannot be banned in admins room.", - )); - } - } - }, - _ => {}, - } - } + if self.services.admin.is_admin_room(&pdu.room_id).await { + self.check_pdu_for_admin_room(&pdu, sender).boxed().await?; } // If redaction event is not authorized, do not append it to the timeline if pdu.kind == TimelineEventType::RoomRedaction { use RoomVersionId::*; - match self.services.state.get_room_version(&pdu.room_id)? { + match self.services.state.get_room_version(&pdu.room_id).await? { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { if !self .services .state_accessor - .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? { - return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); + return Err!(Request(Forbidden("User cannot redact this event."))); } }; }, _ => { - let content = serde_json::from_str::<RoomRedactionEventContent>(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; - + let content: RoomRedactionEventContent = pdu.get_content()?; if let Some(redact_id) = &content.redacts { if !self .services .state_accessor - .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false) + .await? 
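
The inline admin-room policy above moves into a dedicated `check_pdu_for_admin_room`, invoked through `FutureExt::boxed`. Boxing a rarely taken branch like this erases its large, deeply nested future type so it does not inflate the size of every caller's generated future; a minimal sketch under that assumption:

    use futures::FutureExt;

    async fn check_admin_policy(is_admin_room: bool) -> Result<(), String> {
        if is_admin_room {
            // the expensive membership/ban validation would live here
        }
        Ok(())
    }

    async fn caller() -> Result<(), String> {
        // boxing at the call site keeps the enclosing future small
        check_admin_policy(true).boxed().await
    }
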
{ - return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); + return Err!(Request(Forbidden("User cannot redact this event."))); } } }, @@ -922,7 +880,7 @@ pub async fn build_and_append_pdu( // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. - let statehashid = self.services.state.append_to_state(&pdu)?; + let statehashid = self.services.state.append_to_state(&pdu).await?; let pdu_id = self .append_pdu( @@ -933,20 +891,22 @@ pub async fn build_and_append_pdu( vec![(*pdu.event_id).to_owned()], state_lock, ) + .boxed() .await?; // We set the room state after inserting the pdu, so that we never have a moment // in time where events in the current room state do not exist self.services .state - .set_room_state(room_id, statehashid, state_lock)?; + .set_room_state(&pdu.room_id, statehashid, state_lock); let mut servers: HashSet<OwnedServerName> = self .services .state_cache - .room_servers(room_id) - .filter_map(Result::ok) - .collect(); + .room_servers(&pdu.room_id) + .map(ToOwned::to_owned) + .collect() + .await; // In case we are kicking or banning a user, we need to inform their server of // the change @@ -966,7 +926,8 @@ pub async fn build_and_append_pdu( self.services .sending - .send_pdu_servers(servers.into_iter(), &pdu_id)?; + .send_pdu_servers(servers.iter().map(AsRef::as_ref).stream(), &pdu_id) + .await?; Ok(pdu.event_id) } @@ -982,21 +943,25 @@ pub async fn append_incoming_pdu( state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, soft_fail: bool, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result<Option<Vec<u8>>> { + ) -> Result<Option<RawPduId>> { // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. self.services .state - .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?; + .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed) + .await?; if soft_fail { self.services .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + .mark_as_referenced(&pdu.room_id, &pdu.prev_events); + self.services .state - .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?; + .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock) + .await; + return Ok(None); } @@ -1009,71 +974,88 @@ pub async fn append_incoming_pdu( /// Returns an iterator over all PDUs in a room. #[inline] - pub fn all_pdus<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, - ) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> { - self.pdus_after(user_id, room_id, PduCount::min()) + pub async fn all_pdus<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, + ) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> { + self.pdus(Some(user_id), room_id, None).await } - /// Returns an iterator over all events and their tokens in a room that - /// happened before the event with id `until` in reverse-chronological - /// order. + /// Reverse iteration starting at from. 
#[tracing::instrument(skip(self), level = "debug")] - pub fn pdus_until<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, - ) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> { - self.db.pdus_until(user_id, room_id, until) + pub async fn pdus_rev<'a>( + &'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, until: Option<PduCount>, + ) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> { + self.db + .pdus_rev(user_id, room_id, until.unwrap_or_else(PduCount::max)) + .await } - /// Returns an iterator over all events and their token in a room that - /// happened after the event with id `from` in chronological order. + /// Forward iteration starting at from. #[tracing::instrument(skip(self), level = "debug")] - pub fn pdus_after<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, - ) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> { - self.db.pdus_after(user_id, room_id, from) + pub async fn pdus<'a>( + &'a self, user_id: Option<&'a UserId>, room_id: &'a RoomId, from: Option<PduCount>, + ) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> { + self.db + .pdus(user_id, room_id, from.unwrap_or_else(PduCount::min)) + .await } /// Replace a PDU with the redacted form. #[tracing::instrument(skip(self, reason))] - pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> { + pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: ShortRoomId) -> Result { // TODO: Don't reserialize, keep original json - if let Some(pdu_id) = self.get_pdu_id(event_id)? { - let mut pdu = self - .get_pdu_from_id(&pdu_id)? - .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; + let Ok(pdu_id) = self.get_pdu_id(event_id).await else { + // If event does not exist, just noop + return Ok(()); + }; - if let Ok(content) = serde_json::from_str::<ExtractBody>(pdu.content.get()) { - if let Some(body) = content.body { - self.services - .search - .deindex_pdu(shortroomid, &pdu_id, &body)?; - } + let mut pdu = self + .get_pdu_from_id(&pdu_id) + .await + .map_err(|e| err!(Database(error!(?pdu_id, ?event_id, ?e, "PDU ID points to invalid PDU."))))?; + + if let Ok(content) = pdu.get_content::<ExtractBody>() { + if let Some(body) = content.body { + self.services + .search + .deindex_pdu(shortroomid, &pdu_id, &body); } + } - let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?; - pdu.redact(room_version_id, reason)?; + pdu.redact(room_version_id, reason)?; - self.replace_pdu( - &pdu_id, - &utils::to_canonical_object(&pdu).map_err(|e| { - error!("Failed to convert PDU to canonical JSON: {}", e); - Error::bad_database("Failed to convert PDU to canonical JSON.") - })?, - &pdu, - )?; - } - // If event does not exist, just noop - Ok(()) + let obj = utils::to_canonical_object(&pdu) + .map_err(|e| err!(Database(error!(?event_id, ?e, "Failed to convert PDU to canonical JSON"))))?; + + self.replace_pdu(&pdu_id, &obj, &pdu).await } #[tracing::instrument(skip(self))] pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { + if self + .services + .state_cache + .room_joined_count(room_id) + .await + .is_ok_and(|count| count <= 1) + && !self + .services + .state_accessor + .is_world_readable(room_id) + .await + { + // Room is empty (1 user or none), there is no one that can backfill + return Ok(()); + } + let first_pdu = self - 
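
`redact_pdu` above flips from `if let Some(pdu_id) = self.get_pdu_id(...)?` nesting to an early-return `let ... else`, making "unknown event, redaction is a no-op" the explicit base case. The shape in miniature, with a `HashMap` standing in for the pdu-id lookup:

    use std::collections::HashMap;

    fn redact(event_id: &str, pdu_ids: &HashMap<String, u64>) -> Result<(), String> {
        let Some(pdu_id) = pdu_ids.get(event_id) else {
            // event does not exist: nothing to redact
            return Ok(());
        };

        // ... deindex any search entries, apply the redaction, store the
        // replacement PDU under the same id ...
        println!("redacting pdu at {pdu_id}");
        Ok(())
    }
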
.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? + .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id) + .await? .next() - .expect("Room is not empty")?; + .await + .expect("Room is not empty"); if first_pdu.0 < from { // No backfill required, there are still events between them @@ -1083,54 +1065,58 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re let power_levels: RoomPowerLevelsEventContent = self .services .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? + .room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "") + .await .unwrap_or_default(); let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| { if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) { - Some(user_id.server_name().to_owned()) + Some(user_id.server_name()) } else { None } }); - let room_alias_servers = self - .services - .alias - .local_aliases_for_room(room_id) - .filter_map(|alias| { - alias - .ok() - .filter(|alias| !self.services.globals.server_is_ours(alias.server_name())) - .map(|alias| alias.server_name().to_owned()) - }); - - let servers = room_mods - .chain(room_alias_servers) - .chain(self.services.server.config.trusted_servers.clone()) - .filter(|server_name| { - if self.services.globals.server_is_ours(server_name) { - return false; - } - + let canonical_room_alias_server = once( + self.services + .state_accessor + .get_canonical_alias(room_id) + .await, + ) + .filter_map(Result::ok) + .map(|alias| alias.server_name().to_owned()) + .stream(); + + let mut servers = room_mods + .stream() + .map(ToOwned::to_owned) + .chain(canonical_room_alias_server) + .chain( + self.services + .server + .config + .trusted_servers + .iter() + .map(ToOwned::to_owned) + .stream(), + ) + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)) + .filter_map(|server_name| async move { self.services .state_cache - .server_in_room(server_name, room_id) - .unwrap_or(false) - }); + .server_in_room(&server_name, room_id) + .await + .then_some(server_name) + }) + .boxed(); - for backfill_server in servers { + while let Some(ref backfill_server) = servers.next().await { info!("Asking {backfill_server} for backfill"); let response = self .services .sending .send_federation_request( - &backfill_server, + backfill_server, federation::backfill::get_backfill::v1::Request { room_id: room_id.to_owned(), v: vec![first_pdu.1.event_id.as_ref().to_owned()], @@ -1140,9 +1126,8 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re .await; match response { Ok(response) => { - let pub_key_map = RwLock::new(BTreeMap::new()); for pdu in response.pdus { - if let Err(e) = self.backfill_pdu(&backfill_server, pdu, &pub_key_map).await { + if let Err(e) = self.backfill_pdu(backfill_server, pdu).boxed().await { warn!("Failed to add backfilled pdu in room {room_id}: {e}"); } } @@ -1158,12 +1143,9 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re Ok(()) } - #[tracing::instrument(skip(self, pdu, pub_key_map))] - pub async fn backfill_pdu( - &self, origin: &ServerName, pdu: Box<RawJsonValue>, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - ) -> Result<()> { - let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu)?; + 
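
Backfill target selection above becomes one chained stream: room moderators, then the canonical-alias server, then the configured trusted servers, minus ourselves, kept only if they are actually in the room. A self-contained equivalent with plain `futures` (the names and the `server_in_room` check are illustrative):

    use futures::{stream, StreamExt};

    async fn server_in_room(server: &str) -> bool {
        !server.is_empty() // stand-in for the state_cache lookup
    }

    async fn backfill_candidates(
        room_mods: Vec<String>,
        canonical_alias_server: Option<String>,
        trusted: Vec<String>,
    ) -> Vec<String> {
        stream::iter(room_mods)
            .chain(stream::iter(canonical_alias_server)) // Option yields 0 or 1 items
            .chain(stream::iter(trusted))
            .filter(|s| {
                let keep = s != "our.server.name"; // skip ourselves
                async move { keep }
            })
            .filter_map(|s| async move { server_in_room(&s).await.then_some(s) })
            .collect()
            .await
    }
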
#[tracing::instrument(skip(self, pdu))] + pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) -> Result<()> { + let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu).await?; // Lock so we cannot backfill the same pdu twice at the same time let mutex_lock = self @@ -1174,52 +1156,48 @@ pub async fn backfill_pdu( .await; // Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = self.get_pdu_id(&event_id)? { - let pdu_id = pdu_id.to_vec(); + if let Ok(pdu_id) = self.get_pdu_id(&event_id).await { debug!("We already know {event_id} at {pdu_id:?}"); return Ok(()); } - self.services - .server_keys - .fetch_required_signing_keys([&value], pub_key_map) - .await?; - self.services .event_handler - .handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map) + .handle_incoming_pdu(origin, &room_id, &event_id, value, false) .await?; - let value = self.get_pdu_json(&event_id)?.expect("We just created it"); - let pdu = self.get_pdu(&event_id)?.expect("We just created it"); + let value = self + .get_pdu_json(&event_id) + .await + .expect("We just created it"); + let pdu = self.get_pdu(&event_id).await.expect("We just created it"); let shortroomid = self .services .short - .get_shortroomid(&room_id)? + .get_shortroomid(&room_id) + .await .expect("room exists"); let insert_lock = self.mutex_insert.lock(&room_id).await; - let max = u64::MAX; - let count = self.services.globals.next_count()?; - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes()); + let count: i64 = self.services.globals.next_count().unwrap().try_into()?; + + let pdu_id: RawPduId = PduId { + shortroomid, + shorteventid: PduCount::Backfilled(validated!(0 - count)), + } + .into(); // Insert pdu - self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; + self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value); drop(insert_lock); if pdu.kind == TimelineEventType::RoomMessage { - let content = serde_json::from_str::<ExtractBody>(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - + let content: ExtractBody = pdu.get_content()?; if let Some(body) = content.body { - self.services - .search - .index_pdu(shortroomid, &pdu_id, &body)?; + self.services.search.index_pdu(shortroomid, &pdu_id, &body); } } drop(mutex_lock); @@ -1229,15 +1207,67 @@ pub async fn backfill_pdu( } } -#[cfg(test)] -mod tests { - use super::*; +#[implement(Service)] +#[tracing::instrument(skip_all, level = "debug")] +async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Result<()> { + match pdu.event_type() { + TimelineEventType::RoomEncryption => { + return Err!(Request(Forbidden(error!("Encryption not supported in admins room.")))); + }, + TimelineEventType::RoomMember => { + let target = pdu + .state_key() + .filter(|v| v.starts_with('@')) + .unwrap_or(sender.as_str()); + + let server_user = &self.services.globals.server_user.to_string(); + + let content: RoomMemberEventContent = pdu.get_content()?; + match content.membership { + MembershipState::Leave => { + if target == server_user { + return Err!(Request(Forbidden(error!("Server user cannot leave the admins room.")))); + } - #[test] - fn comparisons() { - assert!(PduCount::Normal(1) < PduCount::Normal(2)); - assert!(PduCount::Backfilled(2) < PduCount::Backfilled(1)); - assert!(PduCount::Normal(1) > PduCount::Backfilled(1)); - 
assert!(PduCount::Backfilled(1) < PduCount::Normal(1)); - } + let count = self + .services + .state_cache + .room_members(&pdu.room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + .ready_filter(|user| *user != target) + .boxed() + .count() + .await; + + if count < 2 { + return Err!(Request(Forbidden(error!("Last admin cannot leave the admins room.")))); + } + }, + + MembershipState::Ban if pdu.state_key().is_some() => { + if target == server_user { + return Err!(Request(Forbidden(error!("Server cannot be banned from admins room.")))); + } + + let count = self + .services + .state_cache + .room_members(&pdu.room_id) + .ready_filter(|user| self.services.globals.user_is_local(user)) + .ready_filter(|user| *user != target) + .boxed() + .count() + .await; + + if count < 2 { + return Err!(Request(Forbidden(error!("Last admin cannot be banned from admins room.")))); + } + }, + _ => {}, + }; + }, + _ => {}, + }; + + Ok(()) } diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index 3cf1cdd5939ecdfde3e37b196223b77b666c7160..8ee34f44d7636c32bb0d77377f753fd158ed0fee 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -1,6 +1,11 @@ use std::{collections::BTreeMap, sync::Arc}; -use conduit::{debug_info, trace, utils, Result, Server}; +use conduit::{ + debug_info, trace, + utils::{self, IterStream}, + Result, Server, +}; +use futures::StreamExt; use ruma::{ api::federation::transactions::edu::{Edu, TypingContent}, events::SyncEphemeralRoomEvent, @@ -8,7 +13,7 @@ }; use tokio::sync::{broadcast, RwLock}; -use crate::{globals, sending, Dep}; +use crate::{globals, sending, users, Dep}; pub struct Service { server: Arc<Server>, @@ -23,6 +28,7 @@ pub struct Service { struct Services { globals: Dep<globals::Service>, sending: Dep<sending::Service>, + users: Dep<users::Service>, } impl crate::Service for Service { @@ -32,6 +38,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { services: Services { globals: args.depend::<globals::Service>("globals"), sending: args.depend::<sending::Service>("sending"), + users: args.depend::<users::Service>("users"), }, typing: RwLock::new(BTreeMap::new()), last_typing_update: RwLock::new(BTreeMap::new()), @@ -46,7 +53,7 @@ impl Service { /// Sets a user as typing until the timeout timestamp is reached or /// roomtyping_remove is called. pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { - debug_info!("typing started {:?} in {:?} timeout:{:?}", user_id, room_id, timeout); + debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}"); // update clients self.typing .write() @@ -54,17 +61,19 @@ pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) .entry(room_id.to_owned()) .or_default() .insert(user_id.to_owned(), timeout); + self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if self.services.globals.user_is_local(user_id) { - self.federation_send(room_id, user_id, true)?; + self.federation_send(room_id, user_id, true).await?; } Ok(()) @@ -72,7 +81,7 @@ pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) /// Removes a user from typing before the timeout is reached. 
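
The unit test deleted above pinned down `PduCount`'s ordering: backfilled counts sort before all normal counts, and in reverse among themselves. A standalone sketch preserving those invariants (the real type now lives elsewhere and carries signed backfill counts, so this is illustrative only):

    use std::cmp::Ordering;

    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum PduCount {
        Backfilled(u64), // counts *down* from the room's first known event
        Normal(u64),     // counts up as new events arrive
    }

    impl Ord for PduCount {
        fn cmp(&self, other: &Self) -> Ordering {
            use PduCount::*;
            match (self, other) {
                (Normal(a), Normal(b)) => a.cmp(b),
                (Backfilled(a), Backfilled(b)) => b.cmp(a), // reversed
                (Normal(_), Backfilled(_)) => Ordering::Greater,
                (Backfilled(_), Normal(_)) => Ordering::Less,
            }
        }
    }

    impl PartialOrd for PduCount {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    fn main() {
        assert!(PduCount::Normal(1) < PduCount::Normal(2));
        assert!(PduCount::Backfilled(2) < PduCount::Backfilled(1));
        assert!(PduCount::Normal(1) > PduCount::Backfilled(1));
        assert!(PduCount::Backfilled(1) < PduCount::Normal(1));
    }
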
pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - debug_info!("typing stopped {:?} in {:?}", user_id, room_id); + debug_info!("typing stopped {user_id:?} in {room_id:?}"); // update clients self.typing .write() @@ -80,31 +89,31 @@ pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result< .entry(room_id.to_owned()) .or_default() .remove(user_id); + self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if self.services.globals.user_is_local(user_id) { - self.federation_send(room_id, user_id, false)?; + self.federation_send(room_id, user_id, false).await?; } Ok(()) } - pub async fn wait_for_update(&self, room_id: &RoomId) -> Result<()> { + pub async fn wait_for_update(&self, room_id: &RoomId) { let mut receiver = self.typing_update_sender.subscribe(); while let Ok(next) = receiver.recv().await { if next == room_id { break; } } - - Ok(()) } /// Makes sure that typing events with old timestamps get removed. @@ -123,30 +132,30 @@ async fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { removable.push(user.clone()); } } - - drop(typing); }; if !removable.is_empty() { let typing = &mut self.typing.write().await; let room = typing.entry(room_id.to_owned()).or_default(); for user in &removable { - debug_info!("typing timeout {:?} in {:?}", &user, room_id); + debug_info!("typing timeout {user:?} in {room_id:?}"); room.remove(user); } + // update clients self.last_typing_update .write() .await .insert(room_id.to_owned(), self.services.globals.next_count()?); + if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation - for user in removable { - if self.services.globals.user_is_local(&user) { - self.federation_send(room_id, &user, false)?; + for user in &removable { + if self.services.globals.user_is_local(user) { + self.federation_send(room_id, user, false).await?; } } } @@ -168,22 +177,40 @@ pub async fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> { /// Returns a new typing EDU. 
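
`wait_for_update` above is a plain broadcast-subscribe loop, now infallible. Isolated, the pattern looks like this; note that on the sender side a `send` error only means "nobody is currently waiting", which is why it can be ignored:

    use tokio::sync::broadcast;

    async fn wait_for_update(mut rx: broadcast::Receiver<String>, room_id: &str) {
        while let Ok(next) = rx.recv().await {
            if next == room_id {
                break; // our room changed; wake the caller
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let (tx, rx) = broadcast::channel(16);
        let waiter = tokio::spawn(async move {
            wait_for_update(rx, "!room:example.org").await;
        });
        // a send error only means there are no subscribers right now
        tx.send("!room:example.org".to_owned()).ok();
        waiter.await.unwrap();
    }
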
pub async fn typings_all( - &self, room_id: &RoomId, + &self, room_id: &RoomId, sender_user: &UserId, ) -> Result<SyncEphemeralRoomEvent<ruma::events::typing::TypingEventContent>> { + let room_typing_indicators = self.typing.read().await.get(room_id).cloned(); + + let Some(typing_indicators) = room_typing_indicators else { + return Ok(SyncEphemeralRoomEvent { + content: ruma::events::typing::TypingEventContent { + user_ids: Vec::new(), + }, + }); + }; + + let user_ids: Vec<_> = typing_indicators + .into_keys() + .stream() + .filter_map(|typing_user_id| async move { + (!self + .services + .users + .user_is_ignored(&typing_user_id, sender_user) + .await) + .then_some(typing_user_id) + }) + .collect() + .await; + Ok(SyncEphemeralRoomEvent { content: ruma::events::typing::TypingEventContent { - user_ids: self - .typing - .read() - .await - .get(room_id) - .map(|m| m.keys().cloned().collect()) - .unwrap_or_default(), + user_ids, }, }) } - fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { + async fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { debug_assert!( self.services.globals.user_is_local(user_id), "tried to broadcast typing status of remote user", @@ -197,7 +224,8 @@ fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> R self.services .sending - .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?; + .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing")) + .await?; Ok(()) } diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs deleted file mode 100644 index c71316153acb670156b448c160477551a324ff9f..0000000000000000000000000000000000000000 --- a/src/service/rooms/user/data.rs +++ /dev/null @@ -1,172 +0,0 @@ -use std::sync::Arc; - -use conduit::{utils, Error, Result}; -use database::Map; -use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; - -use crate::{globals, rooms, Dep}; - -pub(super) struct Data { - userroomid_notificationcount: Arc<Map>, - userroomid_highlightcount: Arc<Map>, - roomuserid_lastnotificationread: Arc<Map>, - roomsynctoken_shortstatehash: Arc<Map>, - userroomid_joined: Arc<Map>, - services: Services, -} - -struct Services { - globals: Dep<globals::Service>, - short: Dep<rooms::short::Service>, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - userroomid_notificationcount: db["userroomid_notificationcount"].clone(), - userroomid_highlightcount: db["userroomid_highlightcount"].clone(), - roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit - roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(), - userroomid_joined: db["userroomid_joined"].clone(), - services: Services { - globals: args.depend::<globals::Service>("globals"), - short: args.depend::<rooms::short::Service>("rooms::short"), - }, - } - } - - pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - - 
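
`typings_all` above now snapshots the per-room map out of the `RwLock` before doing any async work, so no read guard is held across an await while `user_is_ignored` is consulted. The same discipline in reduced form (the diff uses a stream `filter_map`; a loop is equivalent):

    use std::collections::BTreeMap;

    use tokio::sync::RwLock;

    async fn user_is_ignored(user: &str) -> bool {
        user.starts_with("@blocked") // stand-in for the account-data lookup
    }

    async fn typing_users(
        typing: &RwLock<BTreeMap<String, BTreeMap<String, u64>>>,
        room_id: &str,
    ) -> Vec<String> {
        // clone the room's map; the read guard drops at the end of this statement
        let Some(indicators) = typing.read().await.get(room_id).cloned() else {
            return Vec::new();
        };

        let mut user_ids = Vec::new();
        for user_id in indicators.into_keys() {
            if !user_is_ignored(&user_id).await {
                user_ids.push(user_id);
            }
        }
        user_ids
    }
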
self.roomuserid_lastnotificationread - .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; - - Ok(()) - } - - pub(super) fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_notificationcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) - } - - pub(super) fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xFF); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_highlightcount - .get(&userroom_id)? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) - } - - pub(super) fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id.as_bytes()); - - Ok(self - .roomuserid_lastnotificationread - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) - }) - .transpose()? - .unwrap_or(0)) - } - - pub(super) fn associate_token_shortstatehash( - &self, room_id: &RoomId, token: u64, shortstatehash: u64, - ) -> Result<()> { - let shortroomid = self - .services - .short - .get_shortroomid(room_id)? - .expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()) - } - - pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { - let shortroomid = self - .services - .short - .get_shortroomid(room_id)? - .expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")) - }) - .transpose() - } - - pub(super) fn get_shared_rooms<'a>( - &'a self, users: Vec<OwnedUserId>, - ) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - self.userroomid_joined - .scan_prefix(prefix) - .map(|(key, _)| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xFF) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? 
- .0 - .saturating_add(1); // +1 because the room id starts AFTER the separator - - let room_id = key[roomid_index..].to_vec(); - - Ok::<_, Error>(room_id) - }) - .filter_map(Result::ok) - }); - - // We use the default compare function because keys are sorted correctly (not - // reversed) - Ok(Box::new( - utils::common_elements(iterators, Ord::cmp) - .expect("users is not empty") - .map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?, - ) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - }), - )) - } -} diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 93d38470f508b0ec4c7d8e7e6cd3fedd20ad83ff..995871342377e66ee2c80513900bc51c896b7fc7 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -1,53 +1,139 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use conduit::{implement, Result}; +use database::{Deserialized, Map}; +use futures::{pin_mut, Stream, StreamExt}; +use ruma::{RoomId, UserId}; -use self::data::Data; +use crate::{globals, rooms, rooms::short::ShortStateHash, Dep}; pub struct Service { db: Data, + services: Services, +} + +struct Data { + userroomid_notificationcount: Arc<Map>, + userroomid_highlightcount: Arc<Map>, + roomuserid_lastnotificationread: Arc<Map>, + roomsynctoken_shortstatehash: Arc<Map>, +} + +struct Services { + globals: Dep<globals::Service>, + short: Dep<rooms::short::Service>, + state_cache: Dep<rooms::state_cache::Service>, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(&args), + db: Data { + userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(), + userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(), + roomuserid_lastnotificationread: args.db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit + roomsynctoken_shortstatehash: args.db["roomsynctoken_shortstatehash"].clone(), + }, + + services: Services { + globals: args.depend::<globals::Service>("globals"), + short: args.depend::<rooms::short::Service>("rooms::short"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.reset_notification_counts(user_id, room_id) - } +#[implement(Service)] +pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { + let userroom_id = (user_id, room_id); + self.db.userroomid_highlightcount.put(userroom_id, 0_u64); + self.db.userroomid_notificationcount.put(userroom_id, 0_u64); - pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - self.db.notification_count(user_id, room_id) - } + let roomuser_id = (room_id, user_id); + let count = self.services.globals.next_count().unwrap(); + self.db + .roomuserid_lastnotificationread + .put(roomuser_id, count); +} - pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - self.db.highlight_count(user_id, room_id) - } +#[implement(Service)] +pub async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (user_id, room_id); + self.db + .userroomid_notificationcount + .qry(&key) + .await + .deserialized() + 
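
The deleted `data.rs` above shows the raw key discipline that the new typed accessors hide: key components joined with `0xFF` (a byte that never occurs in UTF-8 text, hence never in a Matrix identifier) and counters stored big-endian so that byte order matches numeric order. In isolation:

    fn userroom_key(user_id: &str, room_id: &str) -> Vec<u8> {
        let mut key = Vec::with_capacity(user_id.len() + 1 + room_id.len());
        key.extend_from_slice(user_id.as_bytes());
        key.push(0xFF); // separator: 0xFF is not a valid UTF-8 byte
        key.extend_from_slice(room_id.as_bytes());
        key
    }

    fn parse_count(bytes: &[u8]) -> Result<u64, &'static str> {
        let arr: [u8; 8] = bytes.try_into().map_err(|_| "expected exactly 8 bytes")?;
        Ok(u64::from_be_bytes(arr)) // big-endian keeps lexicographic == numeric order
    }
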
.unwrap_or(0) +} - pub fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - self.db.last_notification_read(user_id, room_id) - } +#[implement(Service)] +pub async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (user_id, room_id); + self.db + .userroomid_highlightcount + .qry(&key) + .await + .deserialized() + .unwrap_or(0) +} - pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { - self.db - .associate_token_shortstatehash(room_id, token, shortstatehash) - } +#[implement(Service)] +pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { + let key = (room_id, user_id); + self.db + .roomuserid_lastnotificationread + .qry(&key) + .await + .deserialized() + .unwrap_or(0) +} - pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { - self.db.get_token_shortstatehash(room_id, token) - } +#[implement(Service)] +pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: ShortStateHash) { + let shortroomid = self + .services + .short + .get_shortroomid(room_id) + .await + .expect("room exists"); - pub fn get_shared_rooms(&self, users: Vec<OwnedUserId>) -> Result<impl Iterator<Item = Result<OwnedRoomId>> + '_> { - self.db.get_shared_rooms(users) - } + let key: &[u64] = &[shortroomid, token]; + self.db + .roomsynctoken_shortstatehash + .put(key, shortstatehash); +} + +#[implement(Service)] +pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<ShortStateHash> { + let shortroomid = self.services.short.get_shortroomid(room_id).await?; + + let key: &[u64] = &[shortroomid, token]; + self.db + .roomsynctoken_shortstatehash + .qry(key) + .await + .deserialized() +} + +#[implement(Service)] +pub async fn has_shared_rooms<'a>(&'a self, user_a: &'a UserId, user_b: &'a UserId) -> bool { + let get_shared_rooms = self.get_shared_rooms(user_a, user_b); + + pin_mut!(get_shared_rooms); + get_shared_rooms.next().await.is_some() +} + +//TODO: optimize; replace point-queries with dual iteration +#[implement(Service)] +pub fn get_shared_rooms<'a>( + &'a self, user_a: &'a UserId, user_b: &'a UserId, +) -> impl Stream<Item = &RoomId> + Send + 'a { + self.services + .state_cache + .rooms_joined(user_a) + .filter(|room_id| self.services.state_cache.is_joined(user_b, room_id)) } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 6c8e2544def207f21b924bd1753641b991a5c029..cd25776a5c456a8b4ae91b378b457f0862704f21 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -1,14 +1,21 @@ use std::sync::Arc; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use conduit::{ + utils, + utils::{stream::TryIgnore, ReadyExt}, + Error, Result, +}; +use database::{Database, Deserialized, Map}; +use futures::{Stream, StreamExt}; use ruma::{ServerName, UserId}; use super::{Destination, SendingEvent}; use crate::{globals, Dep}; -type OutgoingSendingIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a>; -type SendingEventIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a>; +pub(super) type OutgoingItem = (Key, SendingEvent, Destination); +pub(super) type SendingItem = (Key, SendingEvent); +pub(super) type QueueItem = (Key, SendingEvent); +pub(super) type Key = Vec<u8>; pub struct Data { servercurrentevent_data: Arc<Map>, @@ -36,58 +43,82 @@ pub(super) fn 
new(args: &crate::Args<'_>) -> Self { } } - #[inline] - pub fn active_requests(&self) -> OutgoingSendingIter<'_> { - Box::new( - self.servercurrentevent_data - .iter() - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), - ) + pub(super) fn delete_active_request(&self, key: &[u8]) { self.servercurrentevent_data.remove(key); } + + pub(super) async fn delete_all_active_requests_for(&self, destination: &Destination) { + let prefix = destination.get_prefix(); + self.servercurrentevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servercurrentevent_data.remove(key)) + .await; } - #[inline] - pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> SendingEventIter<'a> { + pub(super) async fn delete_all_requests_for(&self, destination: &Destination) { let prefix = destination.get_prefix(); - Box::new( - self.servercurrentevent_data - .scan_prefix(prefix) - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), - ) + self.servercurrentevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servercurrentevent_data.remove(key)) + .await; + + self.servernameevent_data + .raw_keys_prefix(&prefix) + .ignore_err() + .ready_for_each(|key| self.servernameevent_data.remove(key)) + .await; } - pub(super) fn delete_active_request(&self, key: &[u8]) -> Result<()> { self.servercurrentevent_data.remove(key) } + pub(super) fn mark_as_active(&self, events: &[QueueItem]) { + for (key, e) in events { + if key.is_empty() { + continue; + } - pub(super) fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> { - let prefix = destination.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { - self.servercurrentevent_data.remove(&key)?; + let value = if let SendingEvent::Edu(value) = &e { + &**value + } else { + &[] + }; + self.servercurrentevent_data.insert(key, value); + self.servernameevent_data.remove(key); } + } - Ok(()) + #[inline] + pub fn active_requests(&self) -> impl Stream<Item = OutgoingItem> + Send + '_ { + self.servercurrentevent_data + .raw_stream() + .ignore_err() + .map(|(key, val)| { + let (dest, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); + + (key.to_vec(), event, dest) + }) } - pub(super) fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> { + #[inline] + pub fn active_requests_for(&self, destination: &Destination) -> impl Stream<Item = SendingItem> + Send + '_ { let prefix = destination.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { - self.servercurrentevent_data.remove(&key).unwrap(); - } + self.servercurrentevent_data + .stream_prefix_raw(&prefix) + .ignore_err() + .map(|(key, val)| { + let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); - for (key, _) in self.servernameevent_data.scan_prefix(prefix) { - self.servernameevent_data.remove(&key).unwrap(); - } - - Ok(()) + (key.to_vec(), event) + }) } - pub(super) fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result<Vec<Vec<u8>>> { + pub(super) fn queue_requests(&self, requests: &[(&SendingEvent, &Destination)]) -> Vec<Vec<u8>> { let mut batch = Vec::new(); let mut keys = Vec::new(); - for (destination, event) in requests { + for (event, destination) in requests { let mut key = destination.get_prefix(); - if let SendingEvent::Pdu(value) = &event { - key.extend_from_slice(value); + if let SendingEvent::Pdu(value) 
= event { + key.extend(value.as_ref()); } else { - key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + key.extend(&self.services.globals.next_count().unwrap().to_be_bytes()); } let value = if let SendingEvent::Edu(value) = &event { &**value @@ -97,56 +128,38 @@ pub(super) fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) - batch.push((key.clone(), value.to_owned())); keys.push(key); } - self.servernameevent_data - .insert_batch(batch.iter().map(database::KeyVal::from))?; - Ok(keys) - } - pub fn queued_requests<'a>( - &'a self, destination: &Destination, - ) -> Box<dyn Iterator<Item = Result<(SendingEvent, Vec<u8>)>> + 'a> { - let prefix = destination.get_prefix(); - return Box::new( - self.servernameevent_data - .scan_prefix(prefix) - .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), - ); + self.servernameevent_data.insert_batch(batch.iter()); + keys } - pub(super) fn mark_as_active(&self, events: &[(SendingEvent, Vec<u8>)]) -> Result<()> { - for (e, key) in events { - if key.is_empty() { - continue; - } - - let value = if let SendingEvent::Edu(value) = &e { - &**value - } else { - &[] - }; - self.servercurrentevent_data.insert(key, value)?; - self.servernameevent_data.remove(key)?; - } + pub fn queued_requests(&self, destination: &Destination) -> impl Stream<Item = QueueItem> + Send + '_ { + let prefix = destination.get_prefix(); + self.servernameevent_data + .stream_prefix_raw(&prefix) + .ignore_err() + .map(|(key, val)| { + let (_, event) = parse_servercurrentevent(key, val).expect("invalid servercurrentevent"); - Ok(()) + (key.to_vec(), event) + }) } - pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { - self.servername_educount - .insert(server_name.as_bytes(), &last_count.to_be_bytes()) + pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) { + self.servername_educount.raw_put(server_name, last_count); } - pub fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> { + pub async fn get_latest_educount(&self, server_name: &ServerName) -> u64 { self.servername_educount - .get(server_name.as_bytes())? 
- .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) - }) + .get(server_name) + .await + .deserialized() + .unwrap_or(0) } } #[tracing::instrument(skip(key), level = "debug")] -fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination, SendingEvent)> { +fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, SendingEvent)> { // Appservices start with a plus Ok::<_, Error>(if key.starts_with(b"+") { let mut parts = key[1..].splitn(2, |&b| b == 0xFF); @@ -162,9 +175,9 @@ fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination, ( Destination::Appservice(server), if value.is_empty() { - SendingEvent::Pdu(event.to_vec()) + SendingEvent::Pdu(event.into()) } else { - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) } else if key.starts_with(b"$") { @@ -189,10 +202,10 @@ fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination, ( Destination::Push(user_id, pushkey_string), if value.is_empty() { - SendingEvent::Pdu(event.to_vec()) + SendingEvent::Pdu(event.into()) } else { // I'm pretty sure this should never be called - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) } else { @@ -212,9 +225,9 @@ fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination, .map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?, ), if value.is_empty() { - SendingEvent::Pdu(event.to_vec()) + SendingEvent::Pdu(event.into()) } else { - SendingEvent::Edu(value) + SendingEvent::Edu(value.to_vec()) }, ) }) diff --git a/src/service/sending/dest.rs b/src/service/sending/dest.rs index 9968acd766e1240615bb1410b27b7278b752cb51..234a0b906ce631b888f93f76c97ef66685b3bde0 100644 --- a/src/service/sending/dest.rs +++ b/src/service/sending/dest.rs @@ -12,7 +12,7 @@ pub enum Destination { #[implement(Destination)] #[must_use] -pub fn get_prefix(&self) -> Vec<u8> { +pub(super) fn get_prefix(&self) -> Vec<u8> { match self { Self::Normal(server) => { let len = server.as_bytes().len().saturating_add(1); diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index b90ea361846df85ba6baed221fde70a678cb26cc..77997f6976b4fa09f7a01e56bfcc79a14095299d 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -7,16 +7,27 @@ use std::{fmt::Debug, sync::Arc}; use async_trait::async_trait; -use conduit::{err, warn, Result, Server}; +use conduit::{ + err, + utils::{ReadyExt, TryReadyExt}, + warn, Result, Server, +}; +use futures::{Stream, StreamExt}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, - OwnedServerName, RoomId, ServerName, UserId, + RoomId, ServerName, UserId, }; use tokio::sync::Mutex; use self::data::Data; -pub use self::dest::Destination; -use crate::{account_data, client, globals, presence, pusher, resolver, rooms, users, Dep}; +pub use self::{ + dest::Destination, + sender::{EDU_LIMIT, PDU_LIMIT}, +}; +use crate::{ + account_data, client, globals, presence, pusher, resolver, rooms, rooms::timeline::RawPduId, server_keys, users, + Dep, +}; pub struct Service { server: Arc<Server>, @@ -40,6 +51,7 @@ struct Services { account_data: Dep<account_data::Service>, appservice: Dep<crate::appservice::Service>, pusher: Dep<pusher::Service>, + server_keys: Dep<server_keys::Service>, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -52,9 +64,9 @@ struct Msg { #[allow(clippy::module_name_repetitions)] #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum 
SendingEvent { - Pdu(Vec<u8>), // pduid - Edu(Vec<u8>), // pdu json - Flush, // none + Pdu(RawPduId), // pduid + Edu(Vec<u8>), // pdu json + Flush, // none } #[async_trait] @@ -77,6 +89,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { account_data: args.depend::<account_data::Service>("account_data"), appservice: args.depend::<crate::appservice::Service>("appservice"), pusher: args.depend::<pusher::Service>("pusher"), + server_keys: args.depend::<server_keys::Service>("server_keys"), }, db: Data::new(&args), sender, @@ -100,11 +113,11 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { #[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")] - pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { + pub fn send_pdu_push(&self, pdu_id: &RawPduId, user: &UserId, pushkey: String) -> Result { let dest = Destination::Push(user.to_owned(), pushkey); - let event = SendingEvent::Pdu(pdu_id.to_owned()); + let event = SendingEvent::Pdu(*pdu_id); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -113,11 +126,11 @@ pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Re } #[tracing::instrument(skip(self), level = "debug")] - pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Result<()> { + pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: RawPduId) -> Result { let dest = Destination::Appservice(appservice_id); let event = SendingEvent::Pdu(pdu_id); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -126,30 +139,31 @@ pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Res } #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")] - pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { + pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &RawPduId) -> Result { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.send_pdu_servers(servers, pdu_id) + self.send_pdu_servers(servers, pdu_id).await } #[tracing::instrument(skip(self, servers, pdu_id), level = "debug")] - pub fn send_pdu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (Destination::Normal(server), SendingEvent::Pdu(pdu_id.to_owned()))) - .collect::<Vec<_>>(); + pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &RawPduId) -> Result + where + S: Stream<Item = &'a ServerName> + Send + 'a, + { let _cork = self.db.db.cork(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::<Vec<_>>(), - )?; + let requests = servers + .map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.to_owned()))) + .collect::<Vec<_>>() + .await; + + let keys = self + .db + .queue_requests(&requests.iter().map(|(o, e)| (e, o)).collect::<Vec<_>>()); + for ((dest, event), queue_id) in requests.into_iter().zip(keys) { self.dispatch(Msg { dest, @@ -166,7 +180,7 @@ pub fn 
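
`SendingEvent::Pdu` above switches from `Vec<u8>` to `RawPduId`, and call sites from `pdu_id.clone()` to passing `*pdu_id`. The real layout lives in `rooms::timeline`; the point of the change, sketched with an assumed 16-byte (shortroomid plus count) layout:

    // A fixed-width, Copy id: no per-queue-entry allocation, and `*pdu_id`
    // replaces `pdu_id.clone()` everywhere.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
    struct RawPduId([u8; 16]);

    impl AsRef<[u8]> for RawPduId {
        fn as_ref(&self) -> &[u8] {
            &self.0
        }
    }

    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    enum SendingEvent {
        Pdu(RawPduId), // pdu id, fixed width
        Edu(Vec<u8>),  // serialized edu json
        Flush,         // no payload
    }
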
send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Resul let dest = Destination::Normal(server.to_owned()); let event = SendingEvent::Edu(serialized); let _cork = self.db.db.cork(); - let keys = self.db.queue_requests(&[(&dest, event.clone())])?; + let keys = self.db.queue_requests(&[(&event, &dest)]); self.dispatch(Msg { dest, event, @@ -175,30 +189,30 @@ pub fn send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Resul } #[tracing::instrument(skip(self, room_id, serialized), level = "debug")] - pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> { + pub async fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.send_edu_servers(servers, serialized) + self.send_edu_servers(servers, serialized).await } #[tracing::instrument(skip(self, servers, serialized), level = "debug")] - pub fn send_edu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, serialized: Vec<u8>) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| (Destination::Normal(server), SendingEvent::Edu(serialized.clone()))) - .collect::<Vec<_>>(); + pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: Vec<u8>) -> Result<()> + where + S: Stream<Item = &'a ServerName> + Send + 'a, + { let _cork = self.db.db.cork(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::<Vec<_>>(), - )?; + let requests = servers + .map(|server| (Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone()))) + .collect::<Vec<_>>() + .await; + + let keys = self + .db + .queue_requests(&requests.iter().map(|(o, e)| (e, o)).collect::<Vec<_>>()); for ((dest, event), queue_id) in requests.into_iter().zip(keys) { self.dispatch(Msg { @@ -212,31 +226,36 @@ pub fn send_edu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, } #[tracing::instrument(skip(self, room_id), level = "debug")] - pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { + pub async fn flush_room(&self, room_id: &RoomId) -> Result<()> { let servers = self .services .state_cache .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| !self.services.globals.server_is_ours(server_name)); + .ready_filter(|server_name| !self.services.globals.server_is_ours(server_name)); - self.flush_servers(servers) + self.flush_servers(servers).await } #[tracing::instrument(skip(self, servers), level = "debug")] - pub fn flush_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I) -> Result<()> { - let requests = servers.into_iter().map(Destination::Normal); - for dest in requests { - self.dispatch(Msg { - dest, - event: SendingEvent::Flush, - queue_id: Vec::<u8>::new(), - })?; - } - - Ok(()) + pub async fn flush_servers<'a, S>(&self, servers: S) -> Result<()> + where + S: Stream<Item = &'a ServerName> + Send + 'a, + { + servers + .map(ToOwned::to_owned) + .map(Destination::Normal) + .map(Ok) + .ready_try_for_each(|dest| { + self.dispatch(Msg { + dest, + event: SendingEvent::Flush, + queue_id: Vec::<u8>::new(), + }) + }) + .await } + /// Sends a request to a federation server #[tracing::instrument(skip_all, name = "request")] pub async fn send_federation_request<T>(&self, dest: &ServerName, request: T) -> 
Result<T::IncomingResponse> where @@ -246,6 +265,16 @@ pub async fn send_federation_request<T>(&self, dest: &ServerName, request: T) -> self.send(client, dest, request).await } + /// Like send_federation_request() but with a very large timeout + #[tracing::instrument(skip_all, name = "synapse")] + pub async fn send_synapse_request<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse> + where + T: OutgoingRequest + Debug + Send, + { + let client = &self.services.client.synapse; + self.send(client, dest, request).await + } + /// Sends a request to an appservice /// /// Only returns None if there is no url specified in the appservice @@ -263,11 +292,10 @@ pub async fn send_appservice_request<T>( /// Cleanup event data /// Used for instance after we remove an appservice registration #[tracing::instrument(skip(self), level = "debug")] - pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { + pub async fn cleanup_events(&self, appservice_id: String) { self.db - .delete_all_requests_for(&Destination::Appservice(appservice_id))?; - - Ok(()) + .delete_all_requests_for(&Destination::Appservice(appservice_id)) + .await; } fn dispatch(&self, msg: Msg) -> Result<()> { diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 9a8f408b5456cb1bb8578748e7972766cf8b13cc..6a8f1b1bdc3612e6c103f714bfe63fb520c6c37b 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -1,7 +1,8 @@ -use std::{fmt::Debug, mem}; +use std::mem; +use bytes::Bytes; use conduit::{ - debug, debug_error, debug_info, debug_warn, err, error::inspect_debug_log, trace, utils::string::EMPTY, Err, Error, + debug, debug_error, debug_warn, err, error::inspect_debug_log, implement, trace, utils::string::EMPTY, Err, Error, Result, }; use http::{header::AUTHORIZATION, HeaderValue}; @@ -14,19 +15,19 @@ }, serde::Base64, server_util::authorization::XMatrix, - ServerName, + CanonicalJsonObject, CanonicalJsonValue, ServerName, ServerSigningKeyId, }; use crate::{ - globals, resolver, + resolver, resolver::{actual::ActualDest, cache::CachedDest}, }; impl super::Service { - #[tracing::instrument(skip(self, client, req), name = "send")] - pub async fn send<T>(&self, client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse> + #[tracing::instrument(skip(self, client, request), name = "send")] + pub async fn send<T>(&self, client: &Client, dest: &ServerName, request: T) -> Result<T::IncomingResponse> where - T: OutgoingRequest + Debug + Send, + T: OutgoingRequest + Send, { if !self.server.config.allow_federation { return Err!(Config("allow_federation", "Federation is disabled.")); @@ -36,14 +37,14 @@ pub async fn send<T>(&self, client: &Client, dest: &ServerName, req: T) -> Resul .server .config .forbidden_remote_server_names - .contains(&dest.to_owned()) + .contains(dest) { - debug_info!("Refusing to send outbound federation request to {dest}"); - return Err!(Request(Forbidden("Federation with this homeserver is not allowed."))); + return Err!(Request(Forbidden(debug_warn!("Federation with {dest} is not allowed.")))); } let actual = self.services.resolver.get_actual_dest(dest).await?; - let request = self.prepare::<T>(dest, &actual, req).await?; + let request = into_http_request::<T>(&actual, request)?; + let request = self.prepare(dest, request)?; self.execute::<T>(dest, &actual, request, client).await } @@ -51,7 +52,7 @@ async fn execute<T>( &self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client, ) -> Result<T::IncomingResponse> where - T: 
OutgoingRequest + Debug + Send, + T: OutgoingRequest + Send, { let url = request.url().clone(); let method = request.method().clone(); @@ -59,25 +60,14 @@ async fn execute<T>( debug!(?method, ?url, "Sending request"); match client.execute(request).await { Ok(response) => handle_response::<T>(&self.services.resolver, dest, actual, &method, &url, response).await, - Err(error) => handle_error::<T>(dest, actual, &method, &url, error), + Err(error) => Err(handle_error(actual, &method, &url, error).expect_err("always returns error")), } } - async fn prepare<T>(&self, dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request> - where - T: OutgoingRequest + Debug + Send, - { - const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11]; - const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY); - - trace!("Preparing request"); - let mut http_request = req - .try_into_http_request::<Vec<u8>>(&actual.string, SATIR, &VERSIONS) - .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; - - sign_request::<T>(&self.services.globals, dest, &mut http_request); + fn prepare(&self, dest: &ServerName, mut request: http::Request<Vec<u8>>) -> Result<Request> { + self.sign_request(&mut request, dest); - let request = Request::try_from(http_request)?; + let request = Request::try_from(request)?; self.validate_url(request.url())?; Ok(request) @@ -97,23 +87,44 @@ fn validate_url(&self, url: &Url) -> Result<()> { async fn handle_response<T>( resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, - mut response: Response, + response: Response, ) -> Result<T::IncomingResponse> where - T: OutgoingRequest + Debug + Send, + T: OutgoingRequest + Send, { + let response = into_http_response(dest, actual, method, url, response).await?; + let result = T::IncomingResponse::try_from_http_response(response); + + if result.is_ok() && !actual.cached { + resolver.set_cached_destination( + dest.to_owned(), + CachedDest { + dest: actual.dest.clone(), + host: actual.host.clone(), + expire: CachedDest::default_expire(), + }, + ); + } + + result.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}"))) +} + +async fn into_http_response( + dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response, +) -> Result<http::Response<Bytes>> { let status = response.status(); trace!( ?status, ?method, request_url = ?url, response_url = ?response.url(), "Received response from {}", - actual.string, + actual.string(), ); let mut http_response_builder = http::Response::builder() .status(status) .version(response.version()); + mem::swap( response.headers_mut(), http_response_builder @@ -138,27 +149,10 @@ async fn handle_response<T>( return Err(Error::Federation(dest.to_owned(), RumaError::from_http_response(http_response))); } - let response = T::IncomingResponse::try_from_http_response(http_response); - if response.is_ok() && !actual.cached { - resolver.set_cached_destination( - dest.to_owned(), - CachedDest { - dest: actual.dest.clone(), - host: actual.host.clone(), - expire: CachedDest::default_expire(), - }, - ); - } - - response.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}"))) + Ok(http_response) } -fn handle_error<T>( - _dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error, -) -> Result<T::IncomingResponse> -where - T: OutgoingRequest + Debug + Send, -{ +fn handle_error(actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error) -> 
Result { if e.is_timeout() || e.is_connect() { e = e.without_url(); debug_warn!("{e:?}"); @@ -178,61 +172,86 @@ fn handle_error<T>( Err(e.into()) } -fn sign_request<T>(globals: &globals::Service, dest: &ServerName, http_request: &mut http::Request<Vec<u8>>) -where - T: OutgoingRequest + Debug + Send, -{ - let mut req_map = serde_json::Map::with_capacity(8); - if !http_request.body().is_empty() { - req_map.insert( - "content".to_owned(), - serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"), - ); - }; +#[implement(super::Service)] +fn sign_request(&self, http_request: &mut http::Request<Vec<u8>>, dest: &ServerName) { + type Member = (String, Value); + type Value = CanonicalJsonValue; + type Object = CanonicalJsonObject; - req_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); - req_map.insert( - "uri".to_owned(), - http_request - .uri() - .path_and_query() - .expect("all requests have a path") - .to_string() - .into(), - ); - req_map.insert("origin".to_owned(), globals.server_name().as_str().into()); - req_map.insert("destination".to_owned(), dest.as_str().into()); + let origin = self.services.globals.server_name(); + let body = http_request.body(); + let uri = http_request + .uri() + .path_and_query() + .expect("http::Request missing path_and_query"); + + let mut req: Object = if !body.is_empty() { + let content: CanonicalJsonValue = serde_json::from_slice(body).expect("failed to serialize body"); + + let authorization: [Member; 5] = [ + ("content".into(), content), + ("destination".into(), dest.as_str().into()), + ("method".into(), http_request.method().as_str().into()), + ("origin".into(), origin.as_str().into()), + ("uri".into(), uri.to_string().into()), + ]; + + authorization.into() + } else { + let authorization: [Member; 4] = [ + ("destination".into(), dest.as_str().into()), + ("method".into(), http_request.method().as_str().into()), + ("origin".into(), origin.as_str().into()), + ("uri".into(), uri.to_string().into()), + ]; - let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap"); - ruma::signatures::sign_json(globals.server_name().as_str(), globals.keypair(), &mut req_json) - .expect("our request json is what ruma expects"); + authorization.into() + }; - let req_json: serde_json::Map<String, serde_json::Value> = - serde_json::from_slice(&serde_json::to_vec(&req_json).unwrap()).unwrap(); + self.services + .server_keys + .sign_json(&mut req) + .expect("request signing failed"); - let signatures = req_json["signatures"] + let signatures = req["signatures"] .as_object() - .expect("signatures object") + .and_then(|object| object[origin.as_str()].as_object()) + .expect("origin signatures object"); + + let key: &ServerSigningKeyId = signatures + .keys() + .next() + .map(|k| k.as_str().try_into()) + .expect("at least one signature from this origin") + .expect("keyid is json string"); + + let sig: Base64 = signatures .values() - .map(|v| { - v.as_object() - .expect("server signatures object") - .iter() - .map(|(k, v)| (k, v.as_str().expect("server signature string"))) - }); - - for signature_server in signatures { - for s in signature_server { - let key = - s.0.as_str() - .try_into() - .expect("valid homeserver signing key ID"); - let sig = Base64::parse(s.1).expect("valid base64"); - - http_request.headers_mut().insert( - AUTHORIZATION, - HeaderValue::from(&XMatrix::new(globals.config.server_name.clone(), dest.to_owned(), key, sig)), - ); - } - } + .next() + .map(|s| 
s.as_str().map(Base64::parse)) + .expect("at least one signature from this origin") + .expect("signature is json string") + .expect("signature is valid base64"); + + let x_matrix = XMatrix::new(origin.into(), dest.into(), key.into(), sig); + let authorization = HeaderValue::from(&x_matrix); + let authorization = http_request + .headers_mut() + .insert(AUTHORIZATION, authorization); + + debug_assert!(authorization.is_none(), "Authorization header already present"); +} + +fn into_http_request<T>(actual: &ActualDest, request: T) -> Result<http::Request<Vec<u8>>> +where + T: OutgoingRequest + Send, +{ + const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11]; + const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY); + + let http_request = request + .try_into_http_request::<Vec<u8>>(actual.string().as_str(), SATIR, &VERSIONS) + .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; + + Ok(http_request) } diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 206bf92bbcae7ee6a9cb5fca59105647ef841dae..ee8182895ab9e3c3fd473d01bc595df13788486d 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -7,28 +7,32 @@ use base64::{engine::general_purpose, Engine as _}; use conduit::{ - debug, debug_warn, error, trace, - utils::{calculate_hash, math::continue_exponential_backoff_secs}, + debug, debug_warn, err, error, + result::LogErr, + trace, + utils::{calculate_hash, math::continue_exponential_backoff_secs, ReadyExt}, warn, Error, Result, }; -use federation::transactions::send_transaction_message; -use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; +use futures::{future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; use ruma::{ - api::federation::{ - self, - transactions::edu::{ - DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap, + api::{ + appservice::event::push_events::v1::Edu as RumaEdu, + federation::transactions::{ + edu::{ + DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap, + }, + send_transaction_message, }, }, device_id, events::{push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType}, - push, uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, RoomId, RoomVersionId, - ServerName, UInt, + push, uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, + RoomVersionId, ServerName, UInt, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::time::sleep_until; -use super::{appservice, Destination, Msg, SendingEvent, Service}; +use super::{appservice, data::QueueItem, Destination, Msg, SendingEvent, Service}; #[derive(Debug)] enum TransactionStatus { @@ -43,27 +47,32 @@ enum TransactionStatus { type SendingFutures<'a> = FuturesUnordered<SendingFuture<'a>>; type CurTransactionStatus = HashMap<Destination, TransactionStatus>; -const DEQUEUE_LIMIT: usize = 48; -const SELECT_EDU_LIMIT: usize = 16; const CLEANUP_TIMEOUT_MS: u64 = 3500; +const SELECT_PRESENCE_LIMIT: usize = 256; +const SELECT_RECEIPT_LIMIT: usize = 256; +const SELECT_EDU_LIMIT: usize = EDU_LIMIT - 2; +const DEQUEUE_LIMIT: usize = 48; + +pub const PDU_LIMIT: usize = 50; +pub const EDU_LIMIT: usize = 100; + impl Service { #[tracing::instrument(skip_all, name = "sender")] pub(super) async fn sender(&self) -> Result<()> { - let receiver = 
self.receiver.lock().await; - let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut statuses: CurTransactionStatus = CurTransactionStatus::new(); + let mut futures: SendingFutures<'_> = FuturesUnordered::new(); + let receiver = self.receiver.lock().await; - self.initial_requests(&futures, &mut statuses); - loop { - debug_assert!(!receiver.is_closed(), "channel error"); + self.initial_requests(&mut futures, &mut statuses).await; + while !receiver.is_closed() { tokio::select! { request = receiver.recv_async() => match request { - Ok(request) => self.handle_request(request, &futures, &mut statuses), + Ok(request) => self.handle_request(request, &mut futures, &mut statuses).await, Err(_) => break, }, Some(response) = futures.next() => { - self.handle_response(response, &futures, &mut statuses); + self.handle_response(response, &mut futures, &mut statuses).await; }, } } @@ -72,18 +81,16 @@ pub(super) async fn sender(&self) -> Result<()> { Ok(()) } - fn handle_response<'a>( - &'a self, response: SendingResult, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, + async fn handle_response<'a>( + &'a self, response: SendingResult, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { match response { - Ok(dest) => self.handle_response_ok(&dest, futures, statuses), - Err((dest, e)) => Self::handle_response_err(dest, futures, statuses, &e), + Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await, + Err((dest, e)) => Self::handle_response_err(dest, statuses, &e), }; } - fn handle_response_err( - dest: Destination, _futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, e: &Error, - ) { + fn handle_response_err(dest: Destination, statuses: &mut CurTransactionStatus, e: &Error) { debug!(dest = ?dest, "{e:?}"); statuses.entry(dest).and_modify(|e| { *e = match e { @@ -94,39 +101,40 @@ fn handle_response_err( }); } - fn handle_response_ok<'a>( - &'a self, dest: &Destination, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, + #[allow(clippy::needless_pass_by_ref_mut)] + async fn handle_response_ok<'a>( + &'a self, dest: &Destination, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { let _cork = self.db.db.cork(); - self.db - .delete_all_active_requests_for(dest) - .expect("all active requests deleted"); + self.db.delete_all_active_requests_for(dest).await; // Find events that have been added since starting the last request let new_events = self .db .queued_requests(dest) - .filter_map(Result::ok) .take(DEQUEUE_LIMIT) - .collect::<Vec<_>>(); + .collect::<Vec<_>>() + .await; // Insert any pdus we found if !new_events.is_empty() { - self.db - .mark_as_active(&new_events) - .expect("marked as active"); - let new_events_vec = new_events.into_iter().map(|(event, _)| event).collect(); - futures.push(Box::pin(self.send_events(dest.clone(), new_events_vec))); + self.db.mark_as_active(&new_events); + + let new_events_vec = new_events.into_iter().map(|(_, event)| event).collect(); + futures.push(self.send_events(dest.clone(), new_events_vec).boxed()); } else { statuses.remove(dest); } } - fn handle_request<'a>(&'a self, msg: Msg, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { - let iv = vec![(msg.event, msg.queue_id)]; - if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) { + #[allow(clippy::needless_pass_by_ref_mut)] + async fn handle_request<'a>( + &'a self, msg: Msg, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus, + ) { + let iv = 
vec![(msg.queue_id, msg.event)]; + if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await { if !events.is_empty() { - futures.push(Box::pin(self.send_events(msg.dest, events))); + futures.push(self.send_events(msg.dest, events).boxed()); } else { statuses.remove(&msg.dest); } @@ -142,7 +150,7 @@ async fn finish_responses<'a>(&'a self, futures: &mut SendingFutures<'a>, status tokio::select! { () = sleep_until(deadline.into()) => break, response = futures.next() => match response { - Some(response) => self.handle_response(response, futures, statuses), + Some(response) => self.handle_response(response, futures, statuses).await, None => return, } } @@ -151,16 +159,17 @@ async fn finish_responses<'a>(&'a self, futures: &mut SendingFutures<'a>, status debug_warn!("Leaving with {} unfinished requests...", futures.len()); } - fn initial_requests<'a>(&'a self, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { + #[allow(clippy::needless_pass_by_ref_mut)] + async fn initial_requests<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) { let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX); let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new(); - for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { + let mut active = self.db.active_requests().boxed(); + + while let Some((key, event, dest)) = active.next().await { let entry = txns.entry(dest.clone()).or_default(); if self.server.config.startup_netburst_keep >= 0 && entry.len() >= keep { - warn!("Dropping unsent event {:?} {:?}", dest, String::from_utf8_lossy(&key)); - self.db - .delete_active_request(&key) - .expect("active request deleted"); + warn!("Dropping unsent event {dest:?} {:?}", String::from_utf8_lossy(&key)); + self.db.delete_active_request(&key); } else { entry.push(event); } @@ -169,16 +178,16 @@ fn initial_requests<'a>(&'a self, futures: &SendingFutures<'a>, statuses: &mut C for (dest, events) in txns { if self.server.config.startup_netburst && !events.is_empty() { statuses.insert(dest.clone(), TransactionStatus::Running); - futures.push(Box::pin(self.send_events(dest.clone(), events))); + futures.push(self.send_events(dest.clone(), events).boxed()); } } } #[tracing::instrument(skip_all, level = "debug")] - fn select_events( + async fn select_events( &self, dest: &Destination, - new_events: Vec<(SendingEvent, Vec<u8>)>, // Events we want to send: event and full key + new_events: Vec<QueueItem>, // Events we want to send: event and full key statuses: &mut CurTransactionStatus, ) -> Result<Option<Vec<SendingEvent>>> { let (allow, retry) = self.select_events_current(dest.clone(), statuses)?; @@ -195,8 +204,8 @@ fn select_events( if retry { self.db .active_requests_for(dest) - .filter_map(Result::ok) - .for_each(|(_, e)| events.push(e)); + .ready_for_each(|(_, e)| events.push(e)) + .await; return Ok(Some(events)); } @@ -204,17 +213,18 @@ fn select_events( // Compose the next transaction let _cork = self.db.db.cork(); if !new_events.is_empty() { - self.db.mark_as_active(&new_events)?; - for (e, _) in new_events { + self.db.mark_as_active(&new_events); + for (_, e) in new_events { events.push(e); } } // Add EDU's into the transaction if let Destination::Normal(server_name) = dest { - if let Ok((select_edus, last_count)) = self.select_edus(server_name) { + if let Ok((select_edus, last_count)) = self.select_edus(server_name).await { + debug_assert!(select_edus.len() <= EDU_LIMIT, "exceeded edus limit"); 
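// Note (aside, not part of the patch): the constants above mirror the federation /send transaction caps from the Matrix spec, which allows at most 50 PDUs and 100 EDUs per transaction. select_edus() emits one EDU per device-list change up to SELECT_EDU_LIMIT, then at most one aggregated receipt EDU and one aggregated presence EDU, so the two slots of headroom in `SELECT_EDU_LIMIT = EDU_LIMIT - 2` are what make the debug_assert above hold. A minimal sketch of that budget (hypothetical free function, not a Service method):
fn edu_selection_within_cap(device_list_edus: usize) -> bool {
    const EDU_LIMIT: usize = 100;
    const SELECT_EDU_LIMIT: usize = EDU_LIMIT - 2;
    // one EDU per device-list change, plus at most one receipt EDU and one presence EDU
    device_list_edus.min(SELECT_EDU_LIMIT) + 2 <= EDU_LIMIT
}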
events.extend(select_edus.into_iter().map(SendingEvent::Edu)); - self.db.set_latest_educount(server_name, last_count)?; + self.db.set_latest_educount(server_name, last_count); } } @@ -225,13 +235,15 @@ fn select_events( fn select_events_current(&self, dest: Destination, statuses: &mut CurTransactionStatus) -> Result<(bool, bool)> { let (mut allow, mut retry) = (true, false); statuses - .entry(dest) + .entry(dest.clone()) // TODO: can we avoid cloning? .and_modify(|e| match e { TransactionStatus::Failed(tries, time) => { // Fail if a request has failed recently (exponential backoff) let min = self.server.config.sender_timeout; let max = self.server.config.sender_retry_backoff_limit; - if continue_exponential_backoff_secs(min, max, time.elapsed(), *tries) { + if continue_exponential_backoff_secs(min, max, time.elapsed(), *tries) + && !matches!(dest, Destination::Appservice(_)) + { allow = false; } else { retry = true; @@ -248,167 +260,230 @@ fn select_events_current(&self, dest: Destination, statuses: &mut CurTransaction } #[tracing::instrument(skip_all, level = "debug")] - fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { + async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { // u64: count of last edu - let since = self.db.get_latest_educount(server_name)?; - let mut events = Vec::new(); + let since = self.db.get_latest_educount(server_name).await; let mut max_edu_count = since; - let mut device_list_changes = HashSet::new(); - - for room_id in self.services.state_cache.server_rooms(server_name) { - let room_id = room_id?; - // Look for device list updates in this room - device_list_changes.extend( - self.services - .users - .keys_changed(room_id.as_ref(), since, None) - .filter_map(Result::ok) - .filter(|user_id| self.services.globals.user_is_local(user_id)), - ); - - if self.server.config.allow_outgoing_read_receipts - && !self.select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? 
- { - break; - } - } + let mut events = Vec::new(); + + self.select_edus_device_changes(server_name, since, &mut max_edu_count, &mut events) + .await; - for user_id in device_list_changes { - // Empty prev id forces synapse to resync; because synapse resyncs, - // we can just insert placeholder data - let edu = Edu::DeviceListUpdate(DeviceListUpdateContent { - user_id, - device_id: device_id!("placeholder").to_owned(), - device_display_name: Some("Placeholder".to_owned()), - stream_id: uint!(1), - prev_id: Vec::new(), - deleted: None, - keys: None, - }); - - events.push(serde_json::to_vec(&edu).expect("json can be serialized")); + if self.server.config.allow_outgoing_read_receipts { + self.select_edus_receipts(server_name, since, &mut max_edu_count, &mut events) + .await; } if self.server.config.allow_outgoing_presence { - self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?; + self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events) + .await; } Ok((events, max_edu_count)) } /// Look for presence - fn select_edus_presence( + async fn select_edus_device_changes( &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>, - ) -> Result<bool> { - // Look for presence updates for this server - let mut presence_updates = Vec::new(); - for (user_id, count, presence_bytes) in self.services.presence.presence_since(since) { - *max_edu_count = cmp::max(count, *max_edu_count); + ) { + debug_assert!(events.len() < SELECT_EDU_LIMIT, "called when edu limit reached"); - if !self.services.globals.user_is_local(&user_id) { - continue; - } + let server_rooms = self.services.state_cache.server_rooms(server_name); - if !self + pin_mut!(server_rooms); + let mut device_list_changes = HashSet::<OwnedUserId>::new(); + while let Some(room_id) = server_rooms.next().await { + let keys_changed = self .services - .state_cache - .server_sees_user(server_name, &user_id)? 
- { - continue; + .users + .room_keys_changed(room_id, since, None) + .ready_filter(|(user_id, _)| self.services.globals.user_is_local(user_id)); + + pin_mut!(keys_changed); + while let Some((user_id, count)) = keys_changed.next().await { + *max_edu_count = cmp::max(count, *max_edu_count); + if !device_list_changes.insert(user_id.into()) { + continue; + } + + // Empty prev id forces synapse to resync; because synapse resyncs, + // we can just insert placeholder data + let edu = Edu::DeviceListUpdate(DeviceListUpdateContent { + user_id: user_id.into(), + device_id: device_id!("placeholder").to_owned(), + device_display_name: Some("Placeholder".to_owned()), + stream_id: uint!(1), + prev_id: Vec::new(), + deleted: None, + keys: None, + }); + + let edu = serde_json::to_vec(&edu).expect("failed to serialize device list update to JSON"); + + events.push(edu); + if events.len() >= SELECT_EDU_LIMIT { + return; + } } + } + } - let presence_event = self - .services - .presence - .from_json_bytes_to_event(&presence_bytes, &user_id)?; - presence_updates.push(PresenceUpdate { - user_id, - presence: presence_event.content.presence, - currently_active: presence_event.content.currently_active.unwrap_or(false), - last_active_ago: presence_event - .content - .last_active_ago - .unwrap_or_else(|| uint!(0)), - status_msg: presence_event.content.status_msg, - }); + /// Look for read receipts in this room + async fn select_edus_receipts( + &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>, + ) { + debug_assert!(events.len() < EDU_LIMIT, "called when edu limit reached"); - if presence_updates.len() >= SELECT_EDU_LIMIT { - break; + let server_rooms = self.services.state_cache.server_rooms(server_name); + + pin_mut!(server_rooms); + let mut num = 0; + let mut receipts = BTreeMap::<OwnedRoomId, ReceiptMap>::new(); + while let Some(room_id) = server_rooms.next().await { + let receipt_map = self + .select_edus_receipts_room(room_id, since, max_edu_count, &mut num) + .await; + + if !receipt_map.read.is_empty() { + receipts.insert(room_id.into(), receipt_map); } } - if !presence_updates.is_empty() { - let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); - events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized")); + if receipts.is_empty() { + return; } - Ok(true) + let receipt_content = Edu::Receipt(ReceiptContent { + receipts, + }); + + let receipt_content = + serde_json::to_vec(&receipt_content).expect("Failed to serialize Receipt EDU to JSON vec"); + + events.push(receipt_content); } /// Look for read receipts in this room - fn select_edus_receipts( - &self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>, - ) -> Result<bool> { - for r in self + async fn select_edus_receipts_room( + &self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, num: &mut usize, + ) -> ReceiptMap { + let receipts = self .services .read_receipt - .readreceipts_since(room_id, since) - { - let (user_id, count, read_receipt) = r?; - *max_edu_count = cmp::max(count, *max_edu_count); + .readreceipts_since(room_id, since); + pin_mut!(receipts); + let mut read = BTreeMap::<OwnedUserId, ReceiptData>::new(); + while let Some((user_id, count, read_receipt)) = receipts.next().await { + *max_edu_count = cmp::max(count, *max_edu_count); if !self.services.globals.user_is_local(&user_id) { continue; } - let event = serde_json::from_str(read_receipt.json().get()) - .map_err(|_| Error::bad_database("Invalid edu event 
in read_receipts."))?; - let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { - let mut read = BTreeMap::new(); + let Ok(event) = serde_json::from_str(read_receipt.json().get()) else { + error!(?user_id, ?count, ?read_receipt, "Invalid edu event in read_receipts."); + continue; + }; + + let AnySyncEphemeralRoomEvent::Receipt(r) = event else { + error!(?user_id, ?count, ?event, "Invalid event type in read_receipts"); + continue; + }; - let (event_id, mut receipt) = r - .content - .0 - .into_iter() - .next() - .expect("we only use one event per read receipt"); - let receipt = receipt - .remove(&ReceiptType::Read) - .expect("our read receipts always set this") - .remove(&user_id) - .expect("our read receipts always have the user here"); - - read.insert( - user_id, - ReceiptData { - data: receipt.clone(), - event_ids: vec![event_id.clone()], - }, - ); + let (event_id, mut receipt) = r + .content + .0 + .into_iter() + .next() + .expect("we only use one event per read receipt"); + + let receipt = receipt + .remove(&ReceiptType::Read) + .expect("our read receipts always set this") + .remove(&user_id) + .expect("our read receipts always have the user here"); + + let receipt_data = ReceiptData { + data: receipt, + event_ids: vec![event_id.clone()], + }; - let receipt_map = ReceiptMap { - read, - }; + if read.insert(user_id, receipt_data).is_none() { + *num = num.saturating_add(1); + if *num >= SELECT_RECEIPT_LIMIT { + break; + } + } + } - let mut receipts = BTreeMap::new(); - receipts.insert(room_id.to_owned(), receipt_map); + ReceiptMap { + read, + } + } - Edu::Receipt(ReceiptContent { - receipts, - }) - } else { - Error::bad_database("Invalid event type in read_receipts"); + /// Look for presence + async fn select_edus_presence( + &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>, + ) { + debug_assert!(events.len() < EDU_LIMIT, "called when edu limit reached"); + + let presence_since = self.services.presence.presence_since(since); + + pin_mut!(presence_since); + let mut presence_updates = HashMap::<OwnedUserId, PresenceUpdate>::new(); + while let Some((user_id, count, presence_bytes)) = presence_since.next().await { + *max_edu_count = cmp::max(count, *max_edu_count); + if !self.services.globals.user_is_local(user_id) { + continue; + } + + if !self + .services + .state_cache + .server_sees_user(server_name, user_id) + .await + { + continue; + } + + let Ok(presence_event) = self + .services + .presence + .from_json_bytes_to_event(presence_bytes, user_id) + .await + .log_err() + else { continue; }; - events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); + let update = PresenceUpdate { + user_id: user_id.into(), + presence: presence_event.content.presence, + currently_active: presence_event.content.currently_active.unwrap_or(false), + status_msg: presence_event.content.status_msg, + last_active_ago: presence_event + .content + .last_active_ago + .unwrap_or_else(|| uint!(0)), + }; - if events.len() >= SELECT_EDU_LIMIT { - return Ok(false); + presence_updates.insert(user_id.into(), update); + if presence_updates.len() >= SELECT_PRESENCE_LIMIT { + break; } } - Ok(true) + if presence_updates.is_empty() { + return; + } + + let presence_content = Edu::Presence(PresenceContent { + push: presence_updates.into_values().collect(), + }); + + let presence_content = serde_json::to_vec(&presence_content).expect("failed to serialize Presence EDU to JSON"); + + events.push(presence_content); } async fn 
send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingResult { @@ -427,60 +502,62 @@ async fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> Sen async fn send_events_dest_appservice( &self, dest: &Destination, id: &str, events: Vec<SendingEvent>, ) -> SendingResult { - let mut pdu_jsons = Vec::new(); + let Some(appservice) = self.services.appservice.get_registration(id).await else { + return Err((dest.clone(), err!(Database(warn!(?id, "Missing appservice registration"))))); + }; + let mut pdu_jsons = Vec::with_capacity( + events + .iter() + .filter(|event| matches!(event, SendingEvent::Pdu(_))) + .count(), + ); + let mut edu_jsons: Vec<RumaEdu> = Vec::with_capacity( + events + .iter() + .filter(|event| matches!(event, SendingEvent::Edu(_))) + .count(), + ); for event in &events { match event { SendingEvent::Pdu(pdu_id) => { - pdu_jsons.push( - self.services - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Event in servernameevent_data not found in db."), - ) - })? - .to_room_event(), - ); + if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await { + pdu_jsons.push(pdu.to_room_event()); + } }, - SendingEvent::Edu(_) | SendingEvent::Flush => { - // Appservices don't need EDUs (?) and flush only; - // no new content + SendingEvent::Edu(edu) => { + if appservice + .receive_ephemeral + .is_some_and(|receive_edus| receive_edus) + { + if let Ok(edu) = serde_json::from_slice(edu) { + edu_jsons.push(edu); + } + } }, + SendingEvent::Flush => {}, // flush only; no new content } } - //debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); + let txn_hash = calculate_hash(events.iter().filter_map(|e| match e { + SendingEvent::Edu(b) => Some(&**b), + SendingEvent::Pdu(b) => Some(b.as_ref()), + SendingEvent::Flush => None, + })); + + let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(txn_hash); + + //debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty + // transaction"); let client = &self.services.client.appservice; match appservice::send_request( client, - self.services - .appservice - .get_registration(id) - .await - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Could not load registration from db."), - ) - })?, + appservice, ruma::api::appservice::event::push_events::v1::Request { events: pdu_jsons, - txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, - SendingEvent::Flush => &[], - }) - .collect::<Vec<_>>(), - ))) - .into(), - ephemeral: Vec::new(), - to_device: Vec::new(), + txn_id: txn_id.into(), + ephemeral: edu_jsons, + to_device: Vec::new(), // TODO }, ) .await @@ -494,23 +571,17 @@ async fn send_events_dest_appservice( async fn send_events_dest_push( &self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec<SendingEvent>, ) -> SendingResult { - let mut pdus = Vec::new(); + let Ok(pusher) = self.services.pusher.get_pusher(userid, pushkey).await else { + return Err((dest.clone(), err!(Database(error!(?userid, ?pushkey, "Missing pusher"))))); + }; + let mut pdus = Vec::new(); for event in &events { match event { SendingEvent::Pdu(pdu_id) => { - pdus.push( - self.services - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? 
- .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Push] Event in servernameevent_data not found in db."), - ) - })?, - ); + if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await { + pdus.push(pdu); + } }, SendingEvent::Edu(_) | SendingEvent::Flush => { // Push gateways don't need EDUs (?) and flush only; @@ -521,36 +592,25 @@ async fn send_events_dest_push( for pdu in pdus { // Redacted events are not notification targets (we don't send push for them) - if let Some(unsigned) = &pdu.unsigned { - if let Ok(unsigned) = serde_json::from_str::<serde_json::Value>(unsigned.get()) { - if unsigned.get("redacted_because").is_some() { - continue; - } - } - } - - let Some(pusher) = self - .services - .pusher - .get_pusher(userid, pushkey) - .map_err(|e| (dest.clone(), e))? - else { + if pdu.contains_unsigned_property("redacted_because", serde_json::Value::is_string) { continue; - }; + } let rules_for_user = self .services .account_data - .get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap_or_default() - .and_then(|event| serde_json::from_str::<PushRulesEvent>(event.get()).ok()) - .map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global); + .get_global(userid, GlobalAccountDataEventType::PushRules) + .await + .map_or_else( + |_| push::Ruleset::server_default(userid), + |ev: PushRulesEvent| ev.content.global, + ); let unread: UInt = self .services .user .notification_count(userid, &pdu.room_id) - .map_err(|e| (dest.clone(), e))? + .await .try_into() .expect("notification count can't go that high"); @@ -559,7 +619,6 @@ async fn send_events_dest_push( .pusher .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) .await - .map(|_response| dest.clone()) .map_err(|e| (dest.clone(), e)); } @@ -586,21 +645,11 @@ async fn send_events_dest_normal( for event in &events { match event { // TODO: check room version and remove event_id if needed - SendingEvent::Pdu(pdu_id) => pdu_jsons.push( - self.convert_to_outgoing_federation_event( - self.services - .timeline - .get_pdu_json_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? 
- .ok_or_else(|| { - error!(?dest, ?server, ?pdu_id, "event not found"); - ( - dest.clone(), - Error::bad_database("[Normal] Event in servernameevent_data not found in db."), - ) - })?, - ), - ), + SendingEvent::Pdu(pdu_id) => { + if let Ok(pdu) = self.services.timeline.get_pdu_json_from_id(pdu_id).await { + pdu_jsons.push(self.convert_to_outgoing_federation_event(pdu).await); + } + }, SendingEvent::Edu(edu) => { if let Ok(raw) = serde_json::from_slice(edu) { edu_jsons.push(raw); @@ -612,22 +661,21 @@ async fn send_events_dest_normal( //debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty // transaction"); - let transaction_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, - SendingEvent::Flush => &[], - }) - .collect::<Vec<_>>(), - )); + + let txn_hash = calculate_hash(events.iter().filter_map(|e| match e { + SendingEvent::Edu(b) => Some(&**b), + SendingEvent::Pdu(b) => Some(b.as_ref()), + SendingEvent::Flush => None, + })); + + let txn_id = &*general_purpose::URL_SAFE_NO_PAD.encode(txn_hash); let request = send_transaction_message::v1::Request { origin: self.server.config.server_name.clone(), pdus: pdu_jsons, edus: edu_jsons, origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - transaction_id: transaction_id.into(), + transaction_id: txn_id.into(), }; let client = &self.services.client.sender; @@ -639,7 +687,7 @@ async fn send_events_dest_normal( .iter() .filter(|(_, res)| res.is_err()) .for_each( - |(pdu_id, res)| warn!(%transaction_id, %server, "error sending PDU {pdu_id} to remote server: {res:?}"), + |(pdu_id, res)| warn!(%txn_id, %server, "error sending PDU {pdu_id} to remote server: {res:?}"), ); }) .map(|_| dest.clone()) @@ -647,7 +695,7 @@ async fn send_events_dest_normal( } /// This does not return a full `Pdu` it is only to satisfy ruma's types. 
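// Note (aside, not part of the patch): the match inside the function below strips `event_id` because from room version 3 onward an event's ID is its reference hash and is not carried in the event body over federation; only v1/v2 events keep an explicit event_id. A minimal sketch of just that rule, assuming a pdu_json shaped like the one handled below:
fn strip_event_id(pdu_json: &mut ruma::CanonicalJsonObject, version: &ruma::RoomVersionId) {
    use ruma::RoomVersionId;
    match version {
        // v1/v2 carry event_id explicitly in the event body
        RoomVersionId::V1 | RoomVersionId::V2 => {},
        // later versions derive it from the reference hash instead
        _ => {
            pdu_json.remove("event_id");
        },
    }
}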
- pub fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> { + pub async fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> { if let Some(unsigned) = pdu_json .get_mut("unsigned") .and_then(|val| val.as_object_mut()) @@ -660,7 +708,7 @@ pub fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonOb .get("room_id") .and_then(|val| RoomId::parse(val.as_str()?).ok()) { - match self.services.state.get_room_version(&room_id) { + match self.services.state.get_room_version(&room_id).await { Ok(room_version_id) => match room_version_id { RoomVersionId::V1 | RoomVersionId::V2 => {}, _ => _ = pdu_json.remove("event_id"), diff --git a/src/service/server_keys/acquire.rs b/src/service/server_keys/acquire.rs new file mode 100644 index 0000000000000000000000000000000000000000..1080d79eff88918d0634734d47c2f8215bbbdd74 --- /dev/null +++ b/src/service/server_keys/acquire.rs @@ -0,0 +1,227 @@ +use std::{ + borrow::Borrow, + collections::{BTreeMap, BTreeSet}, + time::Duration, +}; + +use conduit::{debug, debug_error, debug_warn, error, implement, info, result::FlatOk, trace, warn}; +use futures::{stream::FuturesUnordered, StreamExt}; +use ruma::{ + api::federation::discovery::ServerSigningKeys, serde::Raw, CanonicalJsonObject, OwnedServerName, + OwnedServerSigningKeyId, ServerName, ServerSigningKeyId, +}; +use serde_json::value::RawValue as RawJsonValue; +use tokio::time::{timeout_at, Instant}; + +use super::key_exists; + +type Batch = BTreeMap<OwnedServerName, Vec<OwnedServerSigningKeyId>>; + +#[implement(super::Service)] +pub async fn acquire_events_pubkeys<'a, I>(&self, events: I) +where + I: Iterator<Item = &'a Box<RawJsonValue>> + Send, +{ + type Batch = BTreeMap<OwnedServerName, BTreeSet<OwnedServerSigningKeyId>>; + type Signatures = BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, String>>; + + let mut batch = Batch::new(); + events + .cloned() + .map(Raw::<CanonicalJsonObject>::from_json) + .map(|event| event.get_field::<Signatures>("signatures")) + .filter_map(FlatOk::flat_ok) + .flat_map(IntoIterator::into_iter) + .for_each(|(server, sigs)| { + batch.entry(server).or_default().extend(sigs.into_keys()); + }); + + let batch = batch + .iter() + .map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow))); + + self.acquire_pubkeys(batch).await; +} + +#[implement(super::Service)] +pub async fn acquire_pubkeys<'a, S, K>(&self, batch: S) +where + S: Iterator<Item = (&'a ServerName, K)> + Send + Clone, + K: Iterator<Item = &'a ServerSigningKeyId> + Send + Clone, +{ + let notary_only = self.services.server.config.only_query_trusted_key_servers; + let notary_first_always = self.services.server.config.query_trusted_key_servers_first; + let notary_first_on_join = self + .services + .server + .config + .query_trusted_key_servers_first_on_join; + + let requested_servers = batch.clone().count(); + let requested_keys = batch.clone().flat_map(|(_, key_ids)| key_ids).count(); + + debug!("acquire {requested_keys} keys from {requested_servers}"); + + let mut missing = self.acquire_locals(batch).await; + let mut missing_keys = keys_count(&missing); + let mut missing_servers = missing.len(); + if missing_servers == 0 { + return; + } + + info!("{missing_keys} keys for {missing_servers} servers will be acquired"); + + if notary_first_always || notary_first_on_join { + missing = self.acquire_notary(missing.into_iter()).await; + missing_keys = keys_count(&missing); + missing_servers 
= missing.len(); + if missing_keys == 0 { + return; + } + + warn!("missing {missing_keys} keys for {missing_servers} servers from all notaries first"); + } + + if !notary_only { + missing = self.acquire_origins(missing.into_iter()).await; + missing_keys = keys_count(&missing); + missing_servers = missing.len(); + if missing_keys == 0 { + return; + } + + debug_warn!("missing {missing_keys} keys for {missing_servers} servers unreachable"); + } + + if !notary_first_always && !notary_first_on_join { + missing = self.acquire_notary(missing.into_iter()).await; + missing_keys = keys_count(&missing); + missing_servers = missing.len(); + if missing_keys == 0 { + return; + } + + debug_warn!("still missing {missing_keys} keys for {missing_servers} servers from all notaries."); + } + + if missing_keys > 0 { + warn!( + "did not obtain {missing_keys} keys for {missing_servers} servers out of {requested_keys} total keys for \ + {requested_servers} total servers." + ); + } + + for (server, key_ids) in missing { + debug_warn!(?server, ?key_ids, "missing"); + } +} + +#[implement(super::Service)] +async fn acquire_locals<'a, S, K>(&self, batch: S) -> Batch +where + S: Iterator<Item = (&'a ServerName, K)> + Send, + K: Iterator<Item = &'a ServerSigningKeyId> + Send, +{ + let mut missing = Batch::new(); + for (server, key_ids) in batch { + for key_id in key_ids { + if !self.verify_key_exists(server, key_id).await { + missing + .entry(server.into()) + .or_default() + .push(key_id.into()); + } + } + } + + missing +} + +#[implement(super::Service)] +async fn acquire_origins<I>(&self, batch: I) -> Batch +where + I: Iterator<Item = (OwnedServerName, Vec<OwnedServerSigningKeyId>)> + Send, +{ + let timeout = Instant::now() + .checked_add(Duration::from_secs(45)) + .expect("timeout overflows"); + + let mut requests: FuturesUnordered<_> = batch + .map(|(origin, key_ids)| self.acquire_origin(origin, key_ids, timeout)) + .collect(); + + let mut missing = Batch::new(); + while let Some((origin, key_ids)) = requests.next().await { + if !key_ids.is_empty() { + missing.insert(origin, key_ids); + } + } + + missing +} + +#[implement(super::Service)] +async fn acquire_origin( + &self, origin: OwnedServerName, mut key_ids: Vec<OwnedServerSigningKeyId>, timeout: Instant, +) -> (OwnedServerName, Vec<OwnedServerSigningKeyId>) { + match timeout_at(timeout, self.server_request(&origin)).await { + Err(e) => debug_warn!(?origin, "timed out: {e}"), + Ok(Err(e)) => debug_error!(?origin, "{e}"), + Ok(Ok(server_keys)) => { + trace!( + %origin, + ?key_ids, + ?server_keys, + "received server_keys" + ); + + self.add_signing_keys(server_keys.clone()).await; + key_ids.retain(|key_id| !key_exists(&server_keys, key_id)); + }, + } + + (origin, key_ids) +} + +#[implement(super::Service)] +async fn acquire_notary<I>(&self, batch: I) -> Batch +where + I: Iterator<Item = (OwnedServerName, Vec<OwnedServerSigningKeyId>)> + Send, +{ + let mut missing: Batch = batch.collect(); + for notary in self.services.globals.trusted_servers() { + let missing_keys = keys_count(&missing); + let missing_servers = missing.len(); + debug!("Asking notary {notary} for {missing_keys} missing keys from {missing_servers} servers"); + + let batch = missing + .iter() + .map(|(server, keys)| (server.borrow(), keys.iter().map(Borrow::borrow))); + + match self.batch_notary_request(notary, batch).await { + Err(e) => error!("Failed to contact notary {notary:?}: {e}"), + Ok(results) => { + for server_keys in results { + self.acquire_notary_result(&mut missing, server_keys).await; + } 
+ }, + } + } + + missing +} + +#[implement(super::Service)] +async fn acquire_notary_result(&self, missing: &mut Batch, server_keys: ServerSigningKeys) { + let server = &server_keys.server_name; + self.add_signing_keys(server_keys.clone()).await; + + if let Some(key_ids) = missing.get_mut(server) { + key_ids.retain(|key_id| key_exists(&server_keys, key_id)); + if key_ids.is_empty() { + missing.remove(server); + } + } +} + +fn keys_count(batch: &Batch) -> usize { batch.iter().flat_map(|(_, key_ids)| key_ids.iter()).count() } diff --git a/src/service/server_keys/get.rs b/src/service/server_keys/get.rs new file mode 100644 index 0000000000000000000000000000000000000000..dc4627f7af2564509c6295cdf2c15c3b01e11fd2 --- /dev/null +++ b/src/service/server_keys/get.rs @@ -0,0 +1,117 @@ +use std::borrow::Borrow; + +use conduit::{implement, Err, Result}; +use ruma::{api::federation::discovery::VerifyKey, CanonicalJsonObject, RoomVersionId, ServerName, ServerSigningKeyId}; + +use super::{extract_key, PubKeyMap, PubKeys}; + +#[implement(super::Service)] +pub async fn get_event_keys(&self, object: &CanonicalJsonObject, version: &RoomVersionId) -> Result<PubKeyMap> { + use ruma::signatures::required_keys; + + let required = match required_keys(object, version) { + Ok(required) => required, + Err(e) => return Err!(BadServerResponse("Failed to determine keys required to verify: {e}")), + }; + + let batch = required + .iter() + .map(|(s, ids)| (s.borrow(), ids.iter().map(Borrow::borrow))); + + Ok(self.get_pubkeys(batch).await) +} + +#[implement(super::Service)] +pub async fn get_pubkeys<'a, S, K>(&self, batch: S) -> PubKeyMap +where + S: Iterator<Item = (&'a ServerName, K)> + Send, + K: Iterator<Item = &'a ServerSigningKeyId> + Send, +{ + let mut keys = PubKeyMap::new(); + for (server, key_ids) in batch { + let pubkeys = self.get_pubkeys_for(server, key_ids).await; + keys.insert(server.into(), pubkeys); + } + + keys +} + +#[implement(super::Service)] +pub async fn get_pubkeys_for<'a, I>(&self, origin: &ServerName, key_ids: I) -> PubKeys +where + I: Iterator<Item = &'a ServerSigningKeyId> + Send, +{ + let mut keys = PubKeys::new(); + for key_id in key_ids { + if let Ok(verify_key) = self.get_verify_key(origin, key_id).await { + keys.insert(key_id.into(), verify_key.key); + } + } + + keys +} + +#[implement(super::Service)] +pub async fn get_verify_key(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> { + let notary_first = self.services.server.config.query_trusted_key_servers_first; + let notary_only = self.services.server.config.only_query_trusted_key_servers; + + if let Some(result) = self.verify_keys_for(origin).await.remove(key_id) { + return Ok(result); + } + + if notary_first { + if let Ok(result) = self.get_verify_key_from_notaries(origin, key_id).await { + return Ok(result); + } + } + + if !notary_only { + if let Ok(result) = self.get_verify_key_from_origin(origin, key_id).await { + return Ok(result); + } + } + + if !notary_first { + if let Ok(result) = self.get_verify_key_from_notaries(origin, key_id).await { + return Ok(result); + } + } + + Err!(BadServerResponse(debug_error!( + ?key_id, + ?origin, + "Failed to fetch federation signing-key" + ))) +} + +#[implement(super::Service)] +async fn get_verify_key_from_notaries(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> { + for notary in self.services.globals.trusted_servers() { + if let Ok(server_keys) = self.notary_request(notary, origin).await { + for server_key in server_keys.clone() { + 
self.add_signing_keys(server_key).await; + } + + for server_key in server_keys { + if let Some(result) = extract_key(server_key, key_id) { + return Ok(result); + } + } + } + } + + Err!(Request(NotFound("Failed to fetch signing-key from notaries"))) +} + +#[implement(super::Service)] +async fn get_verify_key_from_origin(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> Result<VerifyKey> { + if let Ok(server_key) = self.server_request(origin).await { + self.add_signing_keys(server_key.clone()).await; + if let Some(result) = extract_key(server_key, key_id) { + return Ok(result); + } + } + + Err!(Request(NotFound("Failed to fetch signing-key from origin"))) +} diff --git a/src/service/server_keys/keypair.rs b/src/service/server_keys/keypair.rs new file mode 100644 index 0000000000000000000000000000000000000000..31a24cdf387a447d12916fc908d3bfb585cd7947 --- /dev/null +++ b/src/service/server_keys/keypair.rs @@ -0,0 +1,64 @@ +use std::sync::Arc; + +use conduit::{debug, debug_info, err, error, utils, utils::string_from_bytes, Result}; +use database::Database; +use ruma::{api::federation::discovery::VerifyKey, serde::Base64, signatures::Ed25519KeyPair}; + +use super::VerifyKeys; + +pub(super) fn init(db: &Arc<Database>) -> Result<(Box<Ed25519KeyPair>, VerifyKeys)> { + let keypair = load(db).inspect_err(|_e| { + error!("Keypair invalid. Deleting..."); + remove(db); + })?; + + let verify_key = VerifyKey { + key: Base64::new(keypair.public_key().to_vec()), + }; + + let id = format!("ed25519:{}", keypair.version()); + let verify_keys: VerifyKeys = [(id.try_into()?, verify_key)].into(); + + Ok((keypair, verify_keys)) +} + +fn load(db: &Arc<Database>) -> Result<Box<Ed25519KeyPair>> { + let (version, key) = db["global"] + .get_blocking(b"keypair") + .map(|ref val| { + // database deserializer is having trouble with this so it's manual for now + let mut elems = val.split(|&b| b == b'\xFF'); + let vlen = elems.next().expect("invalid keypair entry").len(); + let ver = string_from_bytes(&val[..vlen]).expect("invalid keypair version"); + let der = val[vlen.saturating_add(1)..].to_vec(); + debug!("Found existing Ed25519 keypair: {ver:?}"); + (ver, der) + }) + .or_else(|e| { + assert!(e.is_not_found(), "unexpected error fetching keypair"); + create(db) + })?; + + let key = + Ed25519KeyPair::from_der(&key, version).map_err(|e| err!("Failed to load ed25519 keypair from der: {e:?}"))?; + + Ok(Box::new(key)) +} + +fn create(db: &Arc<Database>) -> Result<(String, Vec<u8>)> { + let keypair = Ed25519KeyPair::generate().map_err(|e| err!("Failed to generate new ed25519 keypair: {e:?}"))?; + + let id = utils::rand::string(8); + debug_info!("Generated new Ed25519 keypair: {id:?}"); + + let value: (String, Vec<u8>) = (id, keypair.to_vec()); + db["global"].raw_put(b"keypair", &value); + + Ok(value) +} + +#[inline] +fn remove(db: &Arc<Database>) { + let global = &db["global"]; + global.remove(b"keypair"); +} diff --git a/src/service/server_keys/mod.rs b/src/service/server_keys/mod.rs index a565e5009342a485f386da91c0c76c2976c52dfa..08bcefb630af43694beb57d4205ac67a8eab3f16 100644 --- a/src/service/server_keys/mod.rs +++ b/src/service/server_keys/mod.rs @@ -1,44 +1,70 @@ -use std::{ - collections::{BTreeMap, HashMap, HashSet}, - sync::Arc, - time::{Duration, SystemTime}, +mod acquire; +mod get; +mod keypair; +mod request; +mod sign; +mod verify; + +use std::{collections::BTreeMap, sync::Arc, time::Duration}; + +use conduit::{ + implement, + utils::{timepoint_from_now, IterStream}, + Result, Server, }; - -use 
conduit::{debug, debug_error, debug_warn, err, error, info, trace, warn, Err, Result}; -use futures_util::{stream::FuturesUnordered, StreamExt}; +use database::{Deserialized, Json, Map}; +use futures::StreamExt; use ruma::{ - api::federation::{ - discovery::{ - get_remote_server_keys, - get_remote_server_keys_batch::{self, v2::QueryCriteria}, - get_server_keys, - }, - membership::create_join_event, - }, - serde::Base64, - CanonicalJsonObject, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedServerSigningKeyId, - RoomVersionId, ServerName, + api::federation::discovery::{ServerSigningKeys, VerifyKey}, + serde::Raw, + signatures::{Ed25519KeyPair, PublicKeyMap, PublicKeySet}, + CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, RoomVersionId, ServerName, + ServerSigningKeyId, }; use serde_json::value::RawValue as RawJsonValue; -use tokio::sync::{RwLock, RwLockWriteGuard}; use crate::{globals, sending, Dep}; pub struct Service { + keypair: Box<Ed25519KeyPair>, + verify_keys: VerifyKeys, + minimum_valid: Duration, services: Services, + db: Data, } struct Services { globals: Dep<globals::Service>, sending: Dep<sending::Service>, + server: Arc<Server>, +} + +struct Data { + server_signingkeys: Arc<Map>, } +pub type VerifyKeys = BTreeMap<OwnedServerSigningKeyId, VerifyKey>; +pub type PubKeyMap = PublicKeyMap; +pub type PubKeys = PublicKeySet; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { + let minimum_valid = Duration::from_secs(3600); + + let (keypair, verify_keys) = keypair::init(args.db)?; + debug_assert!(verify_keys.len() == 1, "only one active verify_key supported"); + Ok(Arc::new(Self { + keypair, + verify_keys, + minimum_valid, services: Services { globals: args.depend::<globals::Service>("globals"), sending: args.depend::<sending::Service>("sending"), + server: args.server.clone(), + }, + db: Data { + server_signingkeys: args.db["server_signingkeys"].clone(), }, })) } @@ -46,525 +72,136 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub async fn fetch_required_signing_keys<'a, E>( - &'a self, events: E, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - ) -> Result<()> - where - E: IntoIterator<Item = &'a BTreeMap<String, CanonicalJsonValue>> + Send, - { - let mut server_key_ids = HashMap::new(); - for event in events { - for (signature_server, signature) in event - .get("signatures") - .ok_or(err!(BadServerResponse("No signatures in server response pdu.")))? - .as_object() - .ok_or(err!(BadServerResponse("Invalid signatures object in server response pdu.")))? 
- { - let signature_object = signature.as_object().ok_or(err!(BadServerResponse( - "Invalid signatures content object in server response pdu.", - )))?; - - for signature_id in signature_object.keys() { - server_key_ids - .entry(signature_server.clone()) - .or_insert_with(HashSet::new) - .insert(signature_id.clone()); - } - } - } - - if server_key_ids.is_empty() { - // Nothing to do, can exit early - trace!("server_key_ids is empty, not fetching any keys"); - return Ok(()); - } - - trace!( - "Fetch keys for {}", - server_key_ids - .keys() - .cloned() - .collect::<Vec<_>>() - .join(", ") - ); +#[implement(Service)] +#[inline] +pub fn keypair(&self) -> &Ed25519KeyPair { &self.keypair } + +#[implement(Service)] +#[inline] +pub fn active_key_id(&self) -> &ServerSigningKeyId { self.active_verify_key().0 } + +#[implement(Service)] +#[inline] +pub fn active_verify_key(&self) -> (&ServerSigningKeyId, &VerifyKey) { + debug_assert!(self.verify_keys.len() <= 1, "more than one active verify_key"); + self.verify_keys + .iter() + .next() + .map(|(id, key)| (id.as_ref(), key)) + .expect("missing active verify_key") +} - let mut server_keys: FuturesUnordered<_> = server_key_ids - .into_iter() - .map(|(signature_server, signature_ids)| async { - let fetch_res = self - .fetch_signing_keys_for_server( - signature_server.as_str().try_into().map_err(|e| { - ( - signature_server.clone(), - err!(BadServerResponse( - "Invalid servername in signatures of server response pdu: {e:?}" - )), - ) - })?, - signature_ids.into_iter().collect(), // HashSet to Vec - ) - .await; - - match fetch_res { - Ok(keys) => Ok((signature_server, keys)), - Err(e) => { - debug_error!( - "Signature verification failed: Could not fetch signing key for {signature_server}: {e}", - ); - Err((signature_server, e)) - }, - } - }) - .collect(); - - while let Some(fetch_res) = server_keys.next().await { - match fetch_res { - Ok((signature_server, keys)) => { - pub_key_map - .write() - .await - .insert(signature_server.clone(), keys); - }, - Err((signature_server, e)) => { - debug_warn!("Failed to fetch keys for {signature_server}: {e:?}"); - }, - } - } +#[implement(Service)] +async fn add_signing_keys(&self, new_keys: ServerSigningKeys) { + let origin = &new_keys.server_name; + + // (timo) Not atomic, but this is not critical + let mut keys: ServerSigningKeys = self + .db + .server_signingkeys + .get(origin) + .await + .deserialized() + .unwrap_or_else(|_| { + // Just insert "now", it doesn't matter + ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) + }); + + keys.verify_keys.extend(new_keys.verify_keys); + keys.old_verify_keys.extend(new_keys.old_verify_keys); + self.db.server_signingkeys.raw_put(origin, Json(&keys)); +} - Ok(()) - } +#[implement(Service)] +pub async fn required_keys_exist(&self, object: &CanonicalJsonObject, version: &RoomVersionId) -> bool { + use ruma::signatures::required_keys; - // Gets a list of servers for which we don't have the signing key yet. We go - // over the PDUs and either cache the key or add it to the list that needs to be - // retrieved. 
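// Note (aside, not part of the patch): the cache-walking removed below is superseded by src/service/server_keys/get.rs and acquire.rs earlier in this diff, which consult the local store first and then order notary vs. origin lookups by the two config flags. A simplified sketch of that precedence (hypothetical free function; the stage names are labels only, and `notary_first` stands for query_trusted_key_servers_first, or'd with the on-join variant in acquire_pubkeys()):
fn lookup_order(notary_only: bool, notary_first: bool) -> &'static [&'static str] {
    match (notary_only, notary_first) {
        // only_query_trusted_key_servers: never contact the origin directly
        (true, _) => &["local", "notary"],
        (false, true) => &["local", "notary", "origin"],
        (false, false) => &["local", "origin", "notary"],
    }
}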
- async fn get_server_keys_from_cache( - &self, pdu: &RawJsonValue, - servers: &mut BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>, - _room_version: &RoomVersionId, - pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap<String, BTreeMap<String, Base64>>>, - ) -> Result<()> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - debug_error!("Invalid PDU in server response: {pdu:#?}"); - err!(BadServerResponse(error!("Invalid PDU in server response: {e:?}"))) - })?; - - let signatures = value - .get("signatures") - .ok_or(err!(BadServerResponse("No signatures in server response pdu.")))? - .as_object() - .ok_or(err!(BadServerResponse("Invalid signatures object in server response pdu.")))?; - - for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(err!(BadServerResponse( - "Invalid signatures content object in server response pdu.", - )))?; - - let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); - - let contains_all_ids = - |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id)); - - let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|e| { - err!(BadServerResponse( - "Invalid servername in signatures of server response pdu: {e:?}" - )) - })?; - - if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { - continue; - } - - debug!("Loading signing keys for {origin}"); - let result: BTreeMap<_, _> = self - .services - .globals - .verify_keys_for(origin)? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if !contains_all_ids(&result) { - debug_warn!("Signing key not loaded for {origin}"); - servers.insert(origin.to_owned(), BTreeMap::new()); - } - - pub_key_map.insert(origin.to_string(), result); - } + let Ok(required_keys) = required_keys(object, version) else { + return false; + }; - Ok(()) - } + required_keys + .iter() + .flat_map(|(server, key_ids)| key_ids.iter().map(move |key_id| (server, key_id))) + .stream() + .all(|(server, key_id)| self.verify_key_exists(server, key_id)) + .await +} - /// Batch requests homeserver signing keys from trusted notary key servers - /// (`trusted_servers` config option) - async fn batch_request_signing_keys( - &self, mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - ) -> Result<()> { - for server in self.services.globals.trusted_servers() { - debug!("Asking batch signing keys from trusted server {server}"); - match self - .services - .sending - .send_federation_request( - server, - get_remote_server_keys_batch::v2::Request { - server_keys: servers.clone(), - }, - ) - .await - { - Ok(keys) => { - debug!("Got signing keys: {keys:?}"); - let mut pkm = pub_key_map.write().await; - for k in keys.server_keys { - let k = match k.deserialize() { - Ok(key) => key, - Err(e) => { - warn!( - "Received error {e} while fetching keys from trusted server {server}: {:#?}", - k.into_json() - ); - continue; - }, - }; - - // TODO: Check signature from trusted server? - servers.remove(&k.server_name); - - let result = self - .services - .globals - .db - .add_signing_key(&k.server_name, k.clone())? 
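// The flattening inside required_keys_exist() above, shown without the async
// stream: a {server => [key ids]} map becomes (server, key_id) pairs, and every
// pair must satisfy the lookup. A self-contained std-only sketch; the predicate
// stands in for verify_key_exists().
use std::collections::BTreeMap;

fn all_keys_known(required: &BTreeMap<&str, Vec<&str>>, known: impl Fn(&str, &str) -> bool) -> bool {
    required
        .iter()
        .flat_map(|(server, ids)| ids.iter().map(move |id| (*server, *id)))
        .all(|(server, id)| known(server, id))
}

fn main() {
    let mut required = BTreeMap::new();
    required.insert("example.org", vec!["ed25519:a", "ed25519:b"]);
    assert!(all_keys_known(&required, |_, id| id.starts_with("ed25519:")));
    assert!(!all_keys_known(&required, |_, id| id == "ed25519:a"));
}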
- .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect::<BTreeMap<_, _>>(); - - pkm.insert(k.server_name.to_string(), result); - } - }, - Err(e) => error!( - "Failed sending batched key request to trusted key server {server} for the remote servers \ - {servers:?}: {e}" - ), - } +#[implement(Service)] +pub async fn verify_key_exists(&self, origin: &ServerName, key_id: &ServerSigningKeyId) -> bool { + type KeysMap<'a> = BTreeMap<&'a ServerSigningKeyId, &'a RawJsonValue>; + + let Ok(keys) = self + .db + .server_signingkeys + .get(origin) + .await + .deserialized::<Raw<ServerSigningKeys>>() + else { + return false; + }; + + if let Ok(Some(verify_keys)) = keys.get_field::<KeysMap<'_>>("verify_keys") { + if verify_keys.contains_key(key_id) { + return true; } - - Ok(()) } - /// Requests multiple homeserver signing keys from individual servers (not - /// trused notary servers) - async fn request_signing_keys( - &self, servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - ) -> Result<()> { - debug!("Asking individual servers for signing keys: {servers:?}"); - let mut futures: FuturesUnordered<_> = servers - .into_keys() - .map(|server| async move { - ( - self.services - .sending - .send_federation_request(&server, get_server_keys::v2::Request::new()) - .await, - server, - ) - }) - .collect(); - - while let Some(result) = futures.next().await { - debug!("Received new Future result"); - if let (Ok(get_keys_response), origin) = result { - debug!("Result is from {origin}"); - if let Ok(key) = get_keys_response.server_key.deserialize() { - let result: BTreeMap<_, _> = self - .services - .globals - .db - .add_signing_key(&origin, key)? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - pub_key_map.write().await.insert(origin.to_string(), result); - } - } - debug!("Done handling Future result"); + if let Ok(Some(old_verify_keys)) = keys.get_field::<KeysMap<'_>>("old_verify_keys") { + if old_verify_keys.contains_key(key_id) { + return true; } - - Ok(()) } - pub async fn fetch_join_signing_keys( - &self, event: &create_join_event::v2::Response, room_version: &RoomVersionId, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - ) -> Result<()> { - let mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>> = BTreeMap::new(); - - { - let mut pkm = pub_key_map.write().await; - - // Try to fetch keys, failure is okay. Servers we couldn't find in the cache - // will be added to `servers` - for pdu in event - .room_state - .state - .iter() - .chain(&event.room_state.auth_chain) - { - if let Err(error) = self - .get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm) - .await - { - debug!(%error, "failed to get server keys from cache"); - }; - } - - drop(pkm); - }; - - if servers.is_empty() { - trace!("We had all keys cached locally, not fetching any keys from remote servers"); - return Ok(()); - } - - if self.services.globals.query_trusted_key_servers_first() { - info!( - "query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \ - homeserver signing keys." 
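// verify_key_exists() above deliberately avoids deserializing the whole
// ServerSigningKeys blob: it pulls out just the verify_keys/old_verify_keys maps
// with RawValue values, so only the key ids are parsed. A standalone sketch of
// the same serde_json trick, assuming serde with the derive feature; the input
// shape here is illustrative only.
use std::collections::BTreeMap;

use serde::Deserialize;
use serde_json::value::RawValue;

fn key_id_known(json: &str, key_id: &str) -> bool {
    #[derive(Deserialize)]
    struct Probe<'a> {
        #[serde(borrow, default)]
        verify_keys: BTreeMap<&'a str, &'a RawValue>,
        #[serde(borrow, default)]
        old_verify_keys: BTreeMap<&'a str, &'a RawValue>,
    }

    serde_json::from_str::<Probe<'_>>(json).is_ok_and(|probe| {
        probe.verify_keys.contains_key(key_id) || probe.old_verify_keys.contains_key(key_id)
    })
}

fn main() {
    let json = r#"{"server_name":"example.org","verify_keys":{"ed25519:abc":{"key":"aGVsbG8"}}}"#;
    assert!(key_id_known(json, "ed25519:abc"));
    assert!(!key_id_known(json, "ed25519:xyz"));
}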
- ); - - self.batch_request_signing_keys(servers.clone(), pub_key_map) - .await?; - - if servers.is_empty() { - debug!("Trusted server supplied all signing keys, no more keys to fetch"); - return Ok(()); - } - - debug!("Remaining servers left that the notary/trusted servers did not provide: {servers:?}"); - - self.request_signing_keys(servers.clone(), pub_key_map) - .await?; - } else { - debug!("query_trusted_key_servers_first is set to false, querying individual homeservers first"); - - self.request_signing_keys(servers.clone(), pub_key_map) - .await?; - - if servers.is_empty() { - debug!("Individual homeservers supplied all signing keys, no more keys to fetch"); - return Ok(()); - } + false +} - debug!("Remaining servers left the individual homeservers did not provide: {servers:?}"); +#[implement(Service)] +pub async fn verify_keys_for(&self, origin: &ServerName) -> VerifyKeys { + let mut keys = self + .signing_keys_for(origin) + .await + .map(|keys| merge_old_keys(keys).verify_keys) + .unwrap_or(BTreeMap::new()); - self.batch_request_signing_keys(servers.clone(), pub_key_map) - .await?; - } + if self.services.globals.server_is_ours(origin) { + keys.extend(self.verify_keys.clone().into_iter()); + } - debug!("Search for signing keys done"); + keys +} - /*if servers.is_empty() { - warn!("Failed to find homeserver signing keys for the remaining servers: {servers:?}"); - }*/ +#[implement(Service)] +pub async fn signing_keys_for(&self, origin: &ServerName) -> Result<ServerSigningKeys> { + self.db.server_signingkeys.get(origin).await.deserialized() +} - Ok(()) - } +#[implement(Service)] +fn minimum_valid_ts(&self) -> MilliSecondsSinceUnixEpoch { + let timepoint = timepoint_from_now(self.minimum_valid).expect("SystemTime should not overflow"); + MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow") +} - /// Search the DB for the signing keys of the given server, if we don't have - /// them fetch them from the server and save to our DB. - #[tracing::instrument(skip_all)] - pub async fn fetch_signing_keys_for_server( - &self, origin: &ServerName, signature_ids: Vec<String>, - ) -> Result<BTreeMap<String, Base64>> { - let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id)); - - let mut result: BTreeMap<_, _> = self - .services - .globals - .verify_keys_for(origin)? 
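// What minimum_valid_ts() computes, reduced to std types: "now + minimum_valid"
// as milliseconds since the Unix epoch. Notary requests then ask only for keys
// that stay valid at least that long (one hour, per the Duration set in build()).
// ruma's MilliSecondsSinceUnixEpoch/UInt wrappers are elided in this sketch.
use std::time::{Duration, SystemTime, UNIX_EPOCH};

fn minimum_valid_ts(minimum_valid: Duration) -> u64 {
    let timepoint = SystemTime::now()
        .checked_add(minimum_valid)
        .expect("SystemTime should not overflow");
    timepoint
        .duration_since(UNIX_EPOCH)
        .expect("timepoint is after the epoch")
        .as_millis() as u64
}

fn main() {
    println!("{}", minimum_valid_ts(Duration::from_secs(3600)));
}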
+fn merge_old_keys(mut keys: ServerSigningKeys) -> ServerSigningKeys { + keys.verify_keys.extend( + keys.old_verify_keys + .clone() .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); + .map(|(key_id, old)| (key_id, VerifyKey::new(old.key))), + ); - if contains_all_ids(&result) { - trace!("We have all homeserver signing keys locally for {origin}, not fetching any remotely"); - return Ok(result); - } + keys +} - // i didnt split this out into their own functions because it's relatively small - if self.services.globals.query_trusted_key_servers_first() { - info!( - "query_trusted_key_servers_first is set to true, querying notary trusted servers first for {origin} \ - keys" - ); - - for server in self.services.globals.trusted_servers() { - debug!("Asking notary server {server} for {origin}'s signing key"); - if let Some(server_keys) = self - .services - .sending - .send_federation_request( - server, - get_remote_server_keys::v2::Request::new( - origin.to_owned(), - MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() - .checked_add(Duration::from_secs(3600)) - .expect("SystemTime too large"), - ) - .expect("time is valid"), - ), - ) - .await - .ok() - .map(|resp| { - resp.server_keys - .into_iter() - .filter_map(|e| e.deserialize().ok()) - .collect::<Vec<_>>() - }) { - debug!("Got signing keys: {:?}", server_keys); - for k in server_keys { - self.services - .globals - .db - .add_signing_key(origin, k.clone())?; - result.extend( - k.verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - k.old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - } - - if contains_all_ids(&result) { - return Ok(result); - } - } - } - - debug!("Asking {origin} for their signing keys over federation"); - if let Some(server_key) = self - .services - .sending - .send_federation_request(origin, get_server_keys::v2::Request::new()) - .await - .ok() - .and_then(|resp| resp.server_key.deserialize().ok()) - { - self.services - .globals - .db - .add_signing_key(origin, server_key.clone())?; - - result.extend( - server_key - .verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - server_key - .old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - - if contains_all_ids(&result) { - return Ok(result); - } - } - } else { - info!("query_trusted_key_servers_first is set to false, querying {origin} first"); - debug!("Asking {origin} for their signing keys over federation"); - if let Some(server_key) = self - .services - .sending - .send_federation_request(origin, get_server_keys::v2::Request::new()) - .await - .ok() - .and_then(|resp| resp.server_key.deserialize().ok()) - { - self.services - .globals - .db - .add_signing_key(origin, server_key.clone())?; - - result.extend( - server_key - .verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - server_key - .old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - - if contains_all_ids(&result) { - return Ok(result); - } - } - - for server in self.services.globals.trusted_servers() { - debug!("Asking notary server {server} for {origin}'s signing key"); - if let Some(server_keys) = self - .services - .sending - .send_federation_request( - server, - get_remote_server_keys::v2::Request::new( - origin.to_owned(), - MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() - .checked_add(Duration::from_secs(3600)) - .expect("SystemTime too large"), - ) - .expect("time is 
valid"), - ), - ) - .await - .ok() - .map(|resp| { - resp.server_keys - .into_iter() - .filter_map(|e| e.deserialize().ok()) - .collect::<Vec<_>>() - }) { - debug!("Got signing keys: {server_keys:?}"); - for k in server_keys { - self.services - .globals - .db - .add_signing_key(origin, k.clone())?; - result.extend( - k.verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - k.old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - } - - if contains_all_ids(&result) { - return Ok(result); - } - } - } - } +fn extract_key(mut keys: ServerSigningKeys, key_id: &ServerSigningKeyId) -> Option<VerifyKey> { + keys.verify_keys.remove(key_id).or_else(|| { + keys.old_verify_keys + .remove(key_id) + .map(|old| VerifyKey::new(old.key)) + }) +} - Err!(BadServerResponse(warn!("Failed to find public key for server {origin:?}"))) - } +fn key_exists(keys: &ServerSigningKeys, key_id: &ServerSigningKeyId) -> bool { + keys.verify_keys.contains_key(key_id) || keys.old_verify_keys.contains_key(key_id) } diff --git a/src/service/server_keys/request.rs b/src/service/server_keys/request.rs new file mode 100644 index 0000000000000000000000000000000000000000..7078f7cd04ece5693b1c3aef8d5acf22f7f8c700 --- /dev/null +++ b/src/service/server_keys/request.rs @@ -0,0 +1,120 @@ +use std::{collections::BTreeMap, fmt::Debug}; + +use conduit::{debug, implement, Err, Result}; +use ruma::{ + api::federation::discovery::{ + get_remote_server_keys, + get_remote_server_keys_batch::{self, v2::QueryCriteria}, + get_server_keys, ServerSigningKeys, + }, + OwnedServerName, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId, +}; + +#[implement(super::Service)] +pub(super) async fn batch_notary_request<'a, S, K>( + &self, notary: &ServerName, batch: S, +) -> Result<Vec<ServerSigningKeys>> +where + S: Iterator<Item = (&'a ServerName, K)> + Send, + K: Iterator<Item = &'a ServerSigningKeyId> + Send, +{ + use get_remote_server_keys_batch::v2::Request; + type RumaBatch = BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>; + + let criteria = QueryCriteria { + minimum_valid_until_ts: Some(self.minimum_valid_ts()), + }; + + let mut server_keys = batch.fold(RumaBatch::new(), |mut batch, (server, key_ids)| { + batch + .entry(server.into()) + .or_default() + .extend(key_ids.map(|key_id| (key_id.into(), criteria.clone()))); + + batch + }); + + debug_assert!(!server_keys.is_empty(), "empty batch request to notary"); + + let mut results = Vec::new(); + while let Some(batch) = server_keys + .keys() + .rev() + .take(self.services.server.config.trusted_server_batch_size) + .last() + .cloned() + { + let request = Request { + server_keys: server_keys.split_off(&batch), + }; + + debug!( + ?notary, + ?batch, + remaining = %server_keys.len(), + requesting = ?request.server_keys.keys(), + "notary request" + ); + + let response = self + .services + .sending + .send_synapse_request(notary, request) + .await? 
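// The chunking idiom in batch_notary_request() above, isolated: repeatedly take
// the smallest key among the last `batch_size` keys, then split_off() at it to
// peel exactly that many entries (or fewer, at the end) off the top of the map.
// A self-contained std-only sketch with illustrative types.
use std::collections::BTreeMap;

fn drain_in_batches(mut work: BTreeMap<String, u64>, batch_size: usize) -> Vec<BTreeMap<String, u64>> {
    let mut batches = Vec::new();
    // keys().rev().take(n).last() is the smallest of the largest n keys, so
    // split_off() at that key removes precisely those entries.
    while let Some(pivot) = work.keys().rev().take(batch_size).last().cloned() {
        batches.push(work.split_off(&pivot));
    }
    batches
}

fn main() {
    let work: BTreeMap<_, _> = (0u64..5).map(|i| (format!("server{i}"), i)).collect();
    let batches = drain_in_batches(work, 2);
    assert_eq!(batches.len(), 3); // sizes 2, 2, 1
    assert!(batches.iter().all(|batch| batch.len() <= 2));
}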
+ .server_keys + .into_iter() + .map(|key| key.deserialize()) + .filter_map(Result::ok); + + results.extend(response); + } + + Ok(results) +} + +#[implement(super::Service)] +pub async fn notary_request( + &self, notary: &ServerName, target: &ServerName, +) -> Result<impl Iterator<Item = ServerSigningKeys> + Clone + Debug + Send> { + use get_remote_server_keys::v2::Request; + + let request = Request { + server_name: target.into(), + minimum_valid_until_ts: self.minimum_valid_ts(), + }; + + let response = self + .services + .sending + .send_federation_request(notary, request) + .await? + .server_keys + .into_iter() + .map(|key| key.deserialize()) + .filter_map(Result::ok); + + Ok(response) +} + +#[implement(super::Service)] +pub async fn server_request(&self, target: &ServerName) -> Result<ServerSigningKeys> { + use get_server_keys::v2::Request; + + let server_signing_key = self + .services + .sending + .send_federation_request(target, Request::new()) + .await + .map(|response| response.server_key) + .and_then(|key| key.deserialize().map_err(Into::into))?; + + if server_signing_key.server_name != target { + return Err!(BadServerResponse(debug_warn!( + requested = ?target, + response = ?server_signing_key.server_name, + "Server responded with bogus server_name" + ))); + } + + Ok(server_signing_key) +} diff --git a/src/service/server_keys/sign.rs b/src/service/server_keys/sign.rs new file mode 100644 index 0000000000000000000000000000000000000000..28fd7e8038edc14a6cf154a4babeb26e85c5cc0c --- /dev/null +++ b/src/service/server_keys/sign.rs @@ -0,0 +1,18 @@ +use conduit::{implement, Result}; +use ruma::{CanonicalJsonObject, RoomVersionId}; + +#[implement(super::Service)] +pub fn sign_json(&self, object: &mut CanonicalJsonObject) -> Result { + use ruma::signatures::sign_json; + + let server_name = self.services.globals.server_name().as_str(); + sign_json(server_name, self.keypair(), object).map_err(Into::into) +} + +#[implement(super::Service)] +pub fn hash_and_sign_event(&self, object: &mut CanonicalJsonObject, room_version: &RoomVersionId) -> Result { + use ruma::signatures::hash_and_sign_event; + + let server_name = self.services.globals.server_name().as_str(); + hash_and_sign_event(server_name, self.keypair(), object, room_version).map_err(Into::into) +} diff --git a/src/service/server_keys/verify.rs b/src/service/server_keys/verify.rs new file mode 100644 index 0000000000000000000000000000000000000000..c836e324a52160eec1d8a279373968bf4c02cccf --- /dev/null +++ b/src/service/server_keys/verify.rs @@ -0,0 +1,53 @@ +use conduit::{implement, pdu::gen_event_id_canonical_json, Err, Result}; +use ruma::{signatures::Verified, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, RoomVersionId}; +use serde_json::value::RawValue as RawJsonValue; + +#[implement(super::Service)] +pub async fn validate_and_add_event_id( + &self, pdu: &RawJsonValue, room_version: &RoomVersionId, +) -> Result<(OwnedEventId, CanonicalJsonObject)> { + let (event_id, mut value) = gen_event_id_canonical_json(pdu, room_version)?; + if let Err(e) = self.verify_event(&value, Some(room_version)).await { + return Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}"))); + } + + value.insert("event_id".into(), CanonicalJsonValue::String(event_id.as_str().into())); + + Ok((event_id, value)) +} + +#[implement(super::Service)] +pub async fn validate_and_add_event_id_no_fetch( + &self, pdu: &RawJsonValue, room_version: &RoomVersionId, +) -> Result<(OwnedEventId, CanonicalJsonObject)> { + let (event_id, mut 
value) = gen_event_id_canonical_json(pdu, room_version)?; + if !self.required_keys_exist(&value, room_version).await { + return Err!(BadServerResponse(debug_warn!( + "Event {event_id} cannot be verified: missing keys." + ))); + } + + if let Err(e) = self.verify_event(&value, Some(room_version)).await { + return Err!(BadServerResponse(debug_error!("Event {event_id} failed verification: {e:?}"))); + } + + value.insert("event_id".into(), CanonicalJsonValue::String(event_id.as_str().into())); + + Ok((event_id, value)) +} + +#[implement(super::Service)] +pub async fn verify_event( + &self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>, +) -> Result<Verified> { + let room_version = room_version.unwrap_or(&RoomVersionId::V11); + let keys = self.get_event_keys(event, room_version).await?; + ruma::signatures::verify_event(&keys, event, room_version).map_err(Into::into) +} + +#[implement(super::Service)] +pub async fn verify_json(&self, event: &CanonicalJsonObject, room_version: Option<&RoomVersionId>) -> Result { + let room_version = room_version.unwrap_or(&RoomVersionId::V11); + let keys = self.get_event_keys(event, room_version).await?; + ruma::signatures::verify_json(&keys, event.clone()).map_err(Into::into) +} diff --git a/src/service/service.rs b/src/service/service.rs index 635f782ea6b9525471a46c97092d38e568423a8d..7ec2ea0febd1901186c46ef55ccb93fb588b69a5 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -7,7 +7,7 @@ }; use async_trait::async_trait; -use conduit::{err, error::inspect_log, utils::string::split_once_infallible, Err, Result, Server}; +use conduit::{err, error::inspect_log, utils::string::SplitInfallible, Err, Result, Server}; use database::Database; /// Abstract interface for a Service @@ -51,7 +51,7 @@ pub(crate) struct Args<'a> { /// Dep is a reference to a service used within another service. /// Circular-dependencies between services require this indirection. -pub(crate) struct Dep<T> { +pub(crate) struct Dep<T: Service + Send + Sync> { dep: OnceLock<Arc<T>>, service: Weak<Map>, name: &'static str, @@ -62,25 +62,48 @@ pub(crate) struct Dep<T> { pub(crate) type MapVal = (Weak<dyn Service>, Weak<dyn Any + Send + Sync>); pub(crate) type MapKey = String; -impl<T: Send + Sync + 'static> Deref for Dep<T> { +/// SAFETY: Workaround for a compiler limitation (or bug) where it is Hard to +/// prove the Sync'ness of Dep because services contain circular references +/// to other services through Dep's. The Sync'ness of Dep can still be +/// proved without unsafety by declaring the crate-attribute #![recursion_limit +/// = "192"] but this may take a while. Re-evaluate this when a new trait-solver +/// (such as Chalk) becomes available. +unsafe impl<T: Service> Sync for Dep<T> {} + +/// SAFETY: Ancillary to unsafe impl Sync; while this is not needed to prevent +/// violating the recursion_limit, the trait-solver still spends an inordinate +/// amount of time to prove this. +unsafe impl<T: Service> Send for Dep<T> {} + +impl<T: Service + Send + Sync> Deref for Dep<T> { type Target = Arc<T>; /// Dereference a dependency. The dependency must be ready or panics. 
+ #[inline] fn deref(&self) -> &Self::Target { - self.dep.get_or_init(|| { - let service = self - .service - .upgrade() - .expect("services map exists for dependency initialization."); - - require::<T>(&service, self.name) - }) + self.dep.get_or_init( + #[inline(never)] + || self.init(), + ) + } +} + +impl<T: Service + Send + Sync> Dep<T> { + #[inline] + fn init(&self) -> Arc<T> { + let service = self + .service + .upgrade() + .expect("services map exists for dependency initialization."); + + require::<T>(&service, self.name) } } impl<'a> Args<'a> { /// Create a lazy-reference to a service when constructing another Service. - pub(crate) fn depend<T: Send + Sync + 'a + 'static>(&'a self, name: &'static str) -> Dep<T> { + #[inline] + pub(crate) fn depend<T: Service>(&'a self, name: &'static str) -> Dep<T> { Dep::<T> { dep: OnceLock::new(), service: Arc::downgrade(self.service), @@ -90,17 +113,14 @@ pub(crate) fn depend<T: Send + Sync + 'a + 'static>(&'a self, name: &'static str /// Create a reference immediately to a service when constructing another /// Service. The other service must be constructed. - pub(crate) fn require<T: Send + Sync + 'a + 'static>(&'a self, name: &'static str) -> Arc<T> { - require::<T>(self.service, name) - } + #[inline] + pub(crate) fn require<T: Service>(&'a self, name: &str) -> Arc<T> { require::<T>(self.service, name) } } /// Reference a Service by name. Panics if the Service does not exist or was /// incorrectly cast. -pub(crate) fn require<'a, 'b, T>(map: &'b Map, name: &'a str) -> Arc<T> -where - T: Send + Sync + 'a + 'b + 'static, -{ +#[inline] +fn require<T: Service>(map: &Map, name: &str) -> Arc<T> { try_get::<T>(map, name) .inspect_err(inspect_log) .expect("Failure to reference service required by another service.") @@ -112,9 +132,9 @@ pub(crate) fn require<'a, 'b, T>(map: &'b Map, name: &'a str) -> Arc<T> /// # Panics /// Incorrect type is not a silent failure (None) as the type never has a reason /// to be incorrect. -pub(crate) fn get<'a, 'b, T>(map: &'b Map, name: &'a str) -> Option<Arc<T>> +pub(crate) fn get<T>(map: &Map, name: &str) -> Option<Arc<T>> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { map.read() .expect("locked for reading") @@ -129,9 +149,9 @@ pub(crate) fn get<'a, 'b, T>(map: &'b Map, name: &'a str) -> Option<Arc<T>> /// Reference a Service by name. Returns Err if the Service does not exist or /// was incorrectly cast. -pub(crate) fn try_get<'a, 'b, T>(map: &'b Map, name: &'a str) -> Result<Arc<T>> +pub(crate) fn try_get<T>(map: &Map, name: &str) -> Result<Arc<T>> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { map.read() .expect("locked for reading") @@ -152,4 +172,4 @@ pub(crate) fn try_get<'a, 'b, T>(map: &'b Map, name: &'a str) -> Result<Arc<T>> /// Utility for service implementations; see Service::name() in the trait. 
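// The Dep pattern above in miniature: a OnceLock that resolves its target out of
// a weakly-referenced registry on first deref, which is what lets services hold
// circular references to each other. A simplified sketch; the real Dep also
// routes failures through try_get()/inspect_log and the Service trait.
use std::{
    any::Any,
    collections::HashMap,
    ops::Deref,
    sync::{Arc, OnceLock, RwLock, Weak},
};

type Registry = RwLock<HashMap<String, Arc<dyn Any + Send + Sync>>>;

struct LazyDep<T> {
    cell: OnceLock<Arc<T>>,
    registry: Weak<Registry>,
    name: &'static str,
}

impl<T: Send + Sync + 'static> Deref for LazyDep<T> {
    type Target = Arc<T>;

    fn deref(&self) -> &Self::Target {
        self.cell.get_or_init(|| {
            let registry = self.registry.upgrade().expect("registry alive");
            let guard = registry.read().expect("locked for reading");
            guard
                .get(self.name)
                .unwrap_or_else(|| panic!("service {} not registered", self.name))
                .clone()
                .downcast::<T>()
                .unwrap_or_else(|_| panic!("service {} has unexpected type", self.name))
        })
    }
}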
#[inline] -pub(crate) fn make_name(module_path: &str) -> &str { split_once_infallible(module_path, "::").1 } +pub(crate) fn make_name(module_path: &str) -> &str { module_path.split_once_infallible("::").1 } diff --git a/src/service/services.rs b/src/service/services.rs index 8e69cdbb622328f301c5e197b344c788f6c9635a..b86e7a72192676387c06b139edc2f276157281b1 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -14,7 +14,7 @@ manager::Manager, media, presence, pusher, resolver, rooms, sending, server_keys, service, service::{Args, Map, Service}, - transaction_ids, uiaa, updates, users, + sync, transaction_ids, uiaa, updates, users, }; pub struct Services { @@ -32,6 +32,7 @@ pub struct Services { pub rooms: rooms::Service, pub sending: Arc<sending::Service>, pub server_keys: Arc<server_keys::Service>, + pub sync: Arc<sync::Service>, pub transaction_ids: Arc<transaction_ids::Service>, pub uiaa: Arc<uiaa::Service>, pub updates: Arc<updates::Service>, @@ -96,6 +97,7 @@ macro_rules! build { }, sending: build!(sending::Service), server_keys: build!(server_keys::Service), + sync: build!(sync::Service), transaction_ids: build!(transaction_ids::Service), uiaa: build!(uiaa::Service), updates: build!(updates::Service), @@ -111,8 +113,8 @@ macro_rules! build { pub async fn start(self: &Arc<Self>) -> Result<Arc<Self>> { debug_info!("Starting services..."); - self.admin.set_services(&Some(Arc::clone(self))); - globals::migrations::migrations(self).await?; + self.admin.set_services(Some(Arc::clone(self)).as_ref()); + super::migrations::migrations(self).await?; self.manager .lock() .await @@ -121,6 +123,14 @@ pub async fn start(self: &Arc<Self>) -> Result<Arc<Self>> { .start() .await?; + // set the server user as online + if self.server.config.allow_local_presence { + _ = self + .presence + .ping_presence(&self.globals.server_user, &ruma::presence::PresenceState::Online) + .await; + } + debug_info!("Services startup complete."); Ok(Arc::clone(self)) } @@ -128,12 +138,20 @@ pub async fn start(self: &Arc<Self>) -> Result<Arc<Self>> { pub async fn stop(&self) { info!("Shutting down services..."); + // set the server user as offline + if self.server.config.allow_local_presence { + _ = self + .presence + .ping_presence(&self.globals.server_user, &ruma::presence::PresenceState::Offline) + .await; + } + self.interrupt(); if let Some(manager) = self.manager.lock().await.as_ref() { manager.stop().await; } - self.admin.set_services(&None); + self.admin.set_services(None); debug_info!("Services shutdown complete."); } @@ -193,16 +211,18 @@ fn interrupt(&self) { } } - pub fn try_get<'a, 'b, T>(&'b self, name: &'a str) -> Result<Arc<T>> + #[inline] + pub fn try_get<T>(&self, name: &str) -> Result<Arc<T>> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { service::try_get::<T>(&self.service, name) } - pub fn get<'a, 'b, T>(&'b self, name: &'a str) -> Option<Arc<T>> + #[inline] + pub fn get<T>(&self, name: &str) -> Option<Arc<T>> where - T: Send + Sync + 'a + 'b + 'static, + T: Any + Send + Sync + Sized, { service::get::<T>(&self.service, name) } diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..f1a6ae75e695be58b542d16ff427b5e699a46722 --- /dev/null +++ b/src/service/sync/mod.rs @@ -0,0 +1,280 @@ +mod watch; + +use std::{ + collections::{BTreeMap, BTreeSet}, + sync::{Arc, Mutex, Mutex as StdMutex}, +}; + +use conduit::{Result, Server}; +use database::Map; +use ruma::{ + 
api::client::sync::sync_events::{ + self, + v4::{ExtensionsConfig, SyncRequestList}, + }, + OwnedDeviceId, OwnedRoomId, OwnedUserId, +}; + +use crate::{rooms, Dep}; + +pub struct Service { + db: Data, + services: Services, + connections: DbConnections, +} + +pub struct Data { + todeviceid_events: Arc<Map>, + userroomid_joined: Arc<Map>, + userroomid_invitestate: Arc<Map>, + userroomid_leftstate: Arc<Map>, + userroomid_notificationcount: Arc<Map>, + userroomid_highlightcount: Arc<Map>, + pduid_pdu: Arc<Map>, + keychangeid_userid: Arc<Map>, + roomusertype_roomuserdataid: Arc<Map>, + readreceiptid_readreceipt: Arc<Map>, + userid_lastonetimekeyupdate: Arc<Map>, +} + +struct Services { + server: Arc<Server>, + short: Dep<rooms::short::Service>, + state_cache: Dep<rooms::state_cache::Service>, + typing: Dep<rooms::typing::Service>, +} + +struct SlidingSyncCache { + lists: BTreeMap<String, SyncRequestList>, + subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>, + known_rooms: BTreeMap<String, BTreeMap<OwnedRoomId, u64>>, // For every room, the roomsince number + extensions: ExtensionsConfig, +} + +type DbConnections = Mutex<BTreeMap<DbConnectionsKey, DbConnectionsVal>>; +type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); +type DbConnectionsVal = Arc<Mutex<SlidingSyncCache>>; + +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { + Ok(Arc::new(Self { + db: Data { + todeviceid_events: args.db["todeviceid_events"].clone(), + userroomid_joined: args.db["userroomid_joined"].clone(), + userroomid_invitestate: args.db["userroomid_invitestate"].clone(), + userroomid_leftstate: args.db["userroomid_leftstate"].clone(), + userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(), + userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(), + pduid_pdu: args.db["pduid_pdu"].clone(), + keychangeid_userid: args.db["keychangeid_userid"].clone(), + roomusertype_roomuserdataid: args.db["roomusertype_roomuserdataid"].clone(), + readreceiptid_readreceipt: args.db["readreceiptid_readreceipt"].clone(), + userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(), + }, + services: Services { + server: args.server.clone(), + short: args.depend::<rooms::short::Service>("rooms::short"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + typing: args.depend::<rooms::typing::Service>("rooms::typing"), + }, + connections: StdMutex::new(BTreeMap::new()), + })) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool { + self.connections + .lock() + .unwrap() + .contains_key(&(user_id, device_id, conn_id)) + } + + pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { + self.connections + .lock() + .expect("locked") + .remove(&(user_id, device_id, conn_id)); + } + + pub fn update_sync_request_with_cache( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request, + ) -> BTreeMap<String, BTreeMap<OwnedRoomId, u64>> { + let Some(conn_id) = request.conn_id.clone() else { + return BTreeMap::new(); + }; + + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: 
BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + for (list_id, list) in &mut request.lists { + if let Some(cached_list) = cached.lists.get(list_id) { + if list.sort.is_empty() { + list.sort.clone_from(&cached_list.sort); + }; + if list.room_details.required_state.is_empty() { + list.room_details + .required_state + .clone_from(&cached_list.room_details.required_state); + }; + list.room_details.timeline_limit = list + .room_details + .timeline_limit + .or(cached_list.room_details.timeline_limit); + list.include_old_rooms = list + .include_old_rooms + .clone() + .or_else(|| cached_list.include_old_rooms.clone()); + match (&mut list.filters, cached_list.filters.clone()) { + (Some(list_filters), Some(cached_filters)) => { + list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); + if list_filters.spaces.is_empty() { + list_filters.spaces = cached_filters.spaces; + } + list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted); + list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite); + if list_filters.room_types.is_empty() { + list_filters.room_types = cached_filters.room_types; + } + if list_filters.not_room_types.is_empty() { + list_filters.not_room_types = cached_filters.not_room_types; + } + list_filters.room_name_like = list_filters + .room_name_like + .clone() + .or(cached_filters.room_name_like); + if list_filters.tags.is_empty() { + list_filters.tags = cached_filters.tags; + } + if list_filters.not_tags.is_empty() { + list_filters.not_tags = cached_filters.not_tags; + } + }, + (_, Some(cached_filters)) => list.filters = Some(cached_filters), + (Some(list_filters), _) => list.filters = Some(list_filters.clone()), + (..) 
=> {}, + } + if list.bump_event_types.is_empty() { + list.bump_event_types + .clone_from(&cached_list.bump_event_types); + }; + } + cached.lists.insert(list_id.clone(), list.clone()); + } + + cached + .subscriptions + .extend(request.room_subscriptions.clone()); + request + .room_subscriptions + .extend(cached.subscriptions.clone()); + + request.extensions.e2ee.enabled = request + .extensions + .e2ee + .enabled + .or(cached.extensions.e2ee.enabled); + + request.extensions.to_device.enabled = request + .extensions + .to_device + .enabled + .or(cached.extensions.to_device.enabled); + + request.extensions.account_data.enabled = request + .extensions + .account_data + .enabled + .or(cached.extensions.account_data.enabled); + request.extensions.account_data.lists = request + .extensions + .account_data + .lists + .clone() + .or_else(|| cached.extensions.account_data.lists.clone()); + request.extensions.account_data.rooms = request + .extensions + .account_data + .rooms + .clone() + .or_else(|| cached.extensions.account_data.rooms.clone()); + + cached.extensions = request.extensions.clone(); + + cached.known_rooms.clone() + } + + pub fn update_sync_subscriptions( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, + subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>, + ) { + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + cached.subscriptions = subscriptions; + } + + pub fn update_sync_known_rooms( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String, + new_cached_rooms: BTreeSet<OwnedRoomId>, globalsince: u64, + ) { + let mut cache = self.connections.lock().expect("locked"); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().expect("locked"); + drop(cache); + + for (roomid, lastsince) in cached + .known_rooms + .entry(list_id.clone()) + .or_default() + .iter_mut() + { + if !new_cached_rooms.contains(roomid) { + *lastsince = 0; + } + } + let list = cached.known_rooms.entry(list_id).or_default(); + for roomid in new_cached_rooms { + list.insert(roomid, globalsince); + } + } +} diff --git a/src/service/sync/watch.rs b/src/service/sync/watch.rs new file mode 100644 index 0000000000000000000000000000000000000000..3eb663c12b8c576daa4a1e62e65a92a4118961e3 --- /dev/null +++ b/src/service/sync/watch.rs @@ -0,0 +1,117 @@ +use conduit::{implement, trace, Result}; +use futures::{pin_mut, stream::FuturesUnordered, FutureExt, StreamExt}; +use ruma::{DeviceId, UserId}; + +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result { + let userid_bytes = user_id.as_bytes().to_vec(); + let mut userid_prefix = userid_bytes.clone(); + userid_prefix.push(0xFF); + + let mut userdeviceid_prefix = userid_prefix.clone(); + userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); + 
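// The cache-merge rule used throughout update_sync_request_with_cache() above,
// in isolation: a field the client sent on this request wins, otherwise the
// value cached from the previous request on the same conn_id is reused. Field
// names here are illustrative.
#[derive(Clone, Debug, Default, PartialEq)]
struct ListFilters {
    is_dm: Option<bool>,
    is_encrypted: Option<bool>,
}

fn merge(request: &mut ListFilters, cached: &ListFilters) {
    request.is_dm = request.is_dm.or(cached.is_dm);
    request.is_encrypted = request.is_encrypted.or(cached.is_encrypted);
}

fn main() {
    let cached = ListFilters { is_dm: Some(true), is_encrypted: Some(false) };
    let mut request = ListFilters { is_dm: None, is_encrypted: Some(true) };
    merge(&mut request, &cached);
    // The omitted field fell back to the cache; the sent field was kept.
    assert_eq!(request, ListFilters { is_dm: Some(true), is_encrypted: Some(true) });
}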
userdeviceid_prefix.push(0xFF); + + let mut futures = FuturesUnordered::new(); + + // Return when *any* user changes their keys + // TODO: only send for users they share a room with + futures.push(self.db.todeviceid_events.watch_prefix(&userdeviceid_prefix)); + + futures.push(self.db.userroomid_joined.watch_prefix(&userid_prefix)); + futures.push(self.db.userroomid_invitestate.watch_prefix(&userid_prefix)); + futures.push(self.db.userroomid_leftstate.watch_prefix(&userid_prefix)); + futures.push( + self.db + .userroomid_notificationcount + .watch_prefix(&userid_prefix), + ); + futures.push( + self.db + .userroomid_highlightcount + .watch_prefix(&userid_prefix), + ); + + // Events for rooms we are in + let rooms_joined = self.services.state_cache.rooms_joined(user_id); + + pin_mut!(rooms_joined); + while let Some(room_id) = rooms_joined.next().await { + let Ok(short_roomid) = self.services.short.get_shortroomid(room_id).await else { + continue; + }; + + let roomid_bytes = room_id.as_bytes().to_vec(); + let mut roomid_prefix = roomid_bytes.clone(); + roomid_prefix.push(0xFF); + + // Key changes + futures.push(self.db.keychangeid_userid.watch_prefix(&roomid_prefix)); + + // Room account data + let mut roomuser_prefix = roomid_prefix.clone(); + roomuser_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.db + .roomusertype_roomuserdataid + .watch_prefix(&roomuser_prefix), + ); + + // PDUs + let short_roomid = short_roomid.to_be_bytes().to_vec(); + futures.push(self.db.pduid_pdu.watch_prefix(&short_roomid)); + + // EDUs + let typing_room_id = room_id.to_owned(); + let typing_wait_for_update = async move { + self.services.typing.wait_for_update(&typing_room_id).await; + }; + + futures.push(typing_wait_for_update.boxed()); + futures.push( + self.db + .readreceiptid_readreceipt + .watch_prefix(&roomid_prefix), + ); + } + + let mut globaluserdata_prefix = vec![0xFF]; + globaluserdata_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.db + .roomusertype_roomuserdataid + .watch_prefix(&globaluserdata_prefix), + ); + + // More key changes (used when user is not joined to any rooms) + futures.push(self.db.keychangeid_userid.watch_prefix(&userid_prefix)); + + // One time keys + futures.push( + self.db + .userid_lastonetimekeyupdate + .watch_prefix(&userid_bytes), + ); + + // Server shutdown + let server_shutdown = async move { + while self.services.server.running() { + self.services.server.signal.subscribe().recv().await.ok(); + } + }; + + futures.push(server_shutdown.boxed()); + if !self.services.server.running() { + return Ok(()); + } + + // Wait until one of them finds something + trace!(futures = futures.len(), "watch started"); + futures.next().await; + trace!(futures = futures.len(), "watch finished"); + + Ok(()) +} diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs deleted file mode 100644 index 791b46f01f1bfc9a8a7d57a5aa5040f143cb1247..0000000000000000000000000000000000000000 --- a/src/service/transaction_ids/data.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::sync::Arc; - -use conduit::Result; -use database::{Database, Map}; -use ruma::{DeviceId, TransactionId, UserId}; - -pub struct Data { - userdevicetxnid_response: Arc<Map>, -} - -impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { - Self { - userdevicetxnid_response: db["userdevicetxnid_response"].clone(), - } - } - - pub(super) fn add_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], - ) -> Result<()> { - let mut key =
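// The wake-on-any shape of watch() above, as a runnable sketch: heterogeneous
// boxed futures go into one FuturesUnordered, and the first completion wakes the
// caller. Assumes the futures and tokio crates; the sleeps stand in for
// watch_prefix() watchers.
use std::time::Duration;

use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
use tokio::time::sleep;

#[tokio::main]
async fn main() {
    let mut futures = FuturesUnordered::new();
    futures.push(sleep(Duration::from_millis(50)).boxed());
    futures.push(sleep(Duration::from_millis(10)).boxed());
    futures.push(sleep(Duration::from_secs(3600)).boxed()); // effectively never

    // Returns as soon as *one* watcher fires; the rest are cancelled when
    // `futures` is dropped.
    futures.next().await;
}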
user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xFF); - key.extend_from_slice(txn_id.as_bytes()); - - self.userdevicetxnid_response.insert(&key, data)?; - - Ok(()) - } - - pub(super) fn existing_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, - ) -> Result<Option<database::Handle<'_>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xFF); - key.extend_from_slice(txn_id.as_bytes()); - - // If there's no entry, this is a new transaction - self.userdevicetxnid_response.get(&key) - } -} diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index 78e6337f23aa3b87c30f4bfd769f6bf70a61b4c1..72f60adb1dc56aa9e55dd2fc2c4cab029d0ef75b 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,35 +1,45 @@ -mod data; - use std::sync::Arc; -use conduit::Result; -use data::Data; +use conduit::{implement, Result}; +use database::{Handle, Map}; use ruma::{DeviceId, TransactionId, UserId}; pub struct Service { - pub db: Data, + db: Data, +} + +struct Data { + userdevicetxnid_response: Arc<Map>, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data { + userdevicetxnid_response: args.db["userdevicetxnid_response"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub fn add_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], - ) -> Result<()> { - self.db.add_txnid(user_id, device_id, txn_id, data) - } +#[implement(Service)] +pub fn add_txnid(&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8]) { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); + key.push(0xFF); + key.extend_from_slice(txn_id.as_bytes()); - pub fn existing_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, - ) -> Result<Option<database::Handle<'_>>> { - self.db.existing_txnid(user_id, device_id, txn_id) - } + self.db.userdevicetxnid_response.insert(&key, data); +} + +// If there's no entry, this is a new transaction +#[implement(Service)] +pub async fn existing_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, +) -> Result<Handle<'_>> { + let key = (user_id, device_id, txn_id); + self.db.userdevicetxnid_response.qry(&key).await } diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs deleted file mode 100644 index ce071da09425afc997e942bfac1be0186f88bfff..0000000000000000000000000000000000000000 --- a/src/service/uiaa/data.rs +++ /dev/null @@ -1,87 +0,0 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; - -use conduit::{Error, Result}; -use database::{Database, Map}; -use ruma::{ - api::client::{error::ErrorKind, uiaa::UiaaInfo}, - CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId, -}; - -pub struct Data { - userdevicesessionid_uiaarequest: RwLock<BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>>, - userdevicesessionid_uiaainfo: Arc<Map>, -} - -impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { - Self { - userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - 
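// The key layout behind add_txnid()/existing_txnid() above, as a standalone
// helper: components joined with 0xFF separators, a byte that never occurs in
// valid UTF-8 and therefore cannot collide with identifier content. Parameter
// types are simplified to &str for illustration.
fn userdevicetxnid_key(user_id: &str, device_id: Option<&str>, txn_id: &str) -> Vec<u8> {
    let mut key = user_id.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(device_id.unwrap_or_default().as_bytes());
    key.push(0xFF);
    key.extend_from_slice(txn_id.as_bytes());
    key
}

fn main() {
    let key = userdevicetxnid_key("@alice:example.org", Some("DEVICEID"), "txn1");
    assert_eq!(key.iter().filter(|&&b| b == 0xFF).count(), 2);
}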
userdevicesessionid_uiaainfo: db["userdevicesessionid_uiaainfo"].clone(), - } - } - - pub(super) fn set_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, - ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); - - Ok(()) - } - - pub(super) fn get_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, - ) -> Option<CanonicalJsonValue> { - self.userdevicesessionid_uiaarequest - .read() - .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) - .map(ToOwned::to_owned) - } - - pub(super) fn update_uiaa_session( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, - ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - if let Some(uiaainfo) = uiaainfo { - self.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - )?; - } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; - } - - Ok(()) - } - - pub(super) fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xFF); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - serde_json::from_slice( - &self - .userdevicesessionid_uiaainfo - .get(&userdevicesessionid)? 
- .ok_or(Error::BadRequest(ErrorKind::forbidden(), "UIAA session does not exist."))?, - ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) - } -} diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 6041bbd3479d5f22a7c8b603cdee85bf5dcb79cb..d2865d8823951880fa9fdb8c76148fada5b9bc41 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,25 +1,28 @@ -mod data; - -use std::sync::Arc; +use std::{ + collections::BTreeMap, + sync::{Arc, RwLock}, +}; -use conduit::{error, utils, utils::hash, Error, Result, Server}; -use data::Data; +use conduit::{ + err, error, implement, utils, + utils::{hash, string::EMPTY}, + Error, Result, +}; +use database::{Deserialized, Json, Map}; use ruma::{ api::client::{ error::ErrorKind, uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier}, }, - CanonicalJsonValue, DeviceId, UserId, + CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId, }; use crate::{globals, users, Dep}; -pub const SESSION_ID_LENGTH: usize = 32; - pub struct Service { - server: Arc<Server>, + userdevicesessionid_uiaarequest: RwLock<RequestMap>, + db: Data, services: Services, - pub db: Data, } struct Services { @@ -27,148 +30,211 @@ struct Services { users: Dep<users::Service>, } +struct Data { + userdevicesessionid_uiaainfo: Arc<Map>, +} + +type RequestMap = BTreeMap<RequestKey, CanonicalJsonValue>; +type RequestKey = (OwnedUserId, OwnedDeviceId, String); + +pub const SESSION_ID_LENGTH: usize = 32; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - server: args.server.clone(), + userdevicesessionid_uiaarequest: RwLock::new(RequestMap::new()), + db: Data { + userdevicesessionid_uiaainfo: args.db["userdevicesessionid_uiaainfo"].clone(), + }, services: Services { globals: args.depend::<globals::Service>("globals"), users: args.depend::<users::Service>("users"), }, - db: Data::new(args.db), })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Creates a new Uiaa session. Make sure the session token is unique. - pub fn create( - &self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue, - ) -> Result<()> { - self.db.set_uiaa_request( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), /* TODO: better session error handling (why - * is it optional in ruma?) */ - json_body, - )?; - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), - Some(uiaainfo), - ) - } - - pub fn try_auth( - &self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo, - ) -> Result<(bool, UiaaInfo)> { - let mut uiaainfo = auth.session().map_or_else( - || Ok(uiaainfo.clone()), - |session| self.db.get_uiaa_session(user_id, device_id, session), - )?; - - if uiaainfo.session.is_none() { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - } +/// Creates a new Uiaa session. Make sure the session token is unique. +#[implement(Service)] +pub fn create(&self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue) { + // TODO: better session error handling (why is uiaainfo.session optional in + // ruma?) 
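// Sketch of the in-memory request store that replaces uiaa/data.rs above: UIAA
// request bodies live in a plain RwLock<BTreeMap> keyed by (user, device,
// session) because they only need to outlive the auth session, not a restart.
// Key and value types are simplified for illustration; serde_json is assumed.
use std::{collections::BTreeMap, sync::RwLock};

type RequestKey = (String, String, String); // (user_id, device_id, session)

#[derive(Default)]
struct UiaaRequests {
    map: RwLock<BTreeMap<RequestKey, serde_json::Value>>,
}

impl UiaaRequests {
    fn set(&self, key: RequestKey, request: serde_json::Value) {
        self.map.write().expect("locked for writing").insert(key, request);
    }

    fn get(&self, key: &RequestKey) -> Option<serde_json::Value> {
        self.map.read().expect("locked for reading").get(key).cloned()
    }
}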
+ self.set_uiaa_request( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + json_body, + ); + + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + Some(uiaainfo), + ); +} - match auth { - // Find out what the user completed - AuthData::Password(Password { - identifier, - password, - #[cfg(feature = "element_hacks")] - user, - .. - }) => { - #[cfg(feature = "element_hacks")] - let username = if let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier { - username - } else if let Some(username) = user { - username - } else { - return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); - }; - - #[cfg(not(feature = "element_hacks"))] - let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier - else { - return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); - }; - - let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; - - // Check if password is correct - if let Some(hash) = self.services.users.password_hash(&user_id)? { - let hash_matches = hash::verify_password(password, &hash).is_ok(); - if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { - kind: ErrorKind::forbidden(), - message: "Invalid username or password.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } +#[implement(Service)] +pub async fn try_auth( + &self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo, +) -> Result<(bool, UiaaInfo)> { + let mut uiaainfo = if let Some(session) = auth.session() { + self.get_uiaa_session(user_id, device_id, session).await? + } else { + uiaainfo.clone() + }; + + if uiaainfo.session.is_none() { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + } - // Password was correct! Let's add it to `completed` - uiaainfo.completed.push(AuthType::Password); - }, - AuthData::RegistrationToken(t) => { - if Some(t.token.trim()) == self.server.config.registration_token.as_deref() { - uiaainfo.completed.push(AuthType::RegistrationToken); - } else { + match auth { + // Find out what the user completed + AuthData::Password(Password { + identifier, + password, + #[cfg(feature = "element_hacks")] + user, + .. 
+ }) => { + #[cfg(feature = "element_hacks")] + let username = if let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier { + username + } else if let Some(username) = user { + username + } else { + return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); + }; + + #[cfg(not(feature = "element_hacks"))] + let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier + else { + return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); + }; + + let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; + + // Check if password is correct + if let Ok(hash) = self.services.users.password_hash(&user_id).await { + let hash_matches = hash::verify_password(password, &hash).is_ok(); + if !hash_matches { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { kind: ErrorKind::forbidden(), - message: "Invalid registration token.".to_owned(), + message: "Invalid username or password.".to_owned(), }); return Ok((false, uiaainfo)); } - }, - AuthData::Dummy(_) => { - uiaainfo.completed.push(AuthType::Dummy); - }, - k => error!("type not supported: {:?}", k), - } + } - // Check if a flow now succeeds - let mut completed = false; - 'flows: for flow in &mut uiaainfo.flows { - for stage in &flow.stages { - if !uiaainfo.completed.contains(stage) { - continue 'flows; - } + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push(AuthType::Password); + }, + AuthData::RegistrationToken(t) => { + if self + .services + .globals + .registration_token + .as_ref() + .is_some_and(|reg_token| t.token.trim() == reg_token) + { + uiaainfo.completed.push(AuthType::RegistrationToken); + } else { + uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { + kind: ErrorKind::forbidden(), + message: "Invalid registration token.".to_owned(), + }); + return Ok((false, uiaainfo)); } - // We didn't break, so this flow succeeded! - completed = true; - } + }, + AuthData::Dummy(_) => { + uiaainfo.completed.push(AuthType::Dummy); + }, + k => error!("type not supported: {:?}", k), + } - if !completed { - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - Some(&uiaainfo), - )?; - return Ok((false, uiaainfo)); + // Check if a flow now succeeds + let mut completed = false; + 'flows: for flow in &mut uiaainfo.flows { + for stage in &flow.stages { + if !uiaainfo.completed.contains(stage) { + continue 'flows; + } } + // We didn't break, so this flow succeeded! + completed = true; + } - // UIAA was successful! Remove this session and return true - self.db.update_uiaa_session( + if !completed { + self.update_uiaa_session( user_id, device_id, uiaainfo.session.as_ref().expect("session is always set"), - None, - )?; - Ok((true, uiaainfo)) + Some(&uiaainfo), + ); + + return Ok((false, uiaainfo)); } - #[must_use] - pub fn get_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, - ) -> Option<CanonicalJsonValue> { - self.db.get_uiaa_request(user_id, device_id, session) + // UIAA was successful! 
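// The labelled-loop check performed above, in isolation: authentication succeeds
// once every stage of at least one advertised flow appears in the `completed`
// list. A std-only sketch with string stand-ins for AuthType.
fn any_flow_complete(flows: &[Vec<&str>], completed: &[&str]) -> bool {
    flows
        .iter()
        .any(|flow| flow.iter().all(|stage| completed.contains(stage)))
}

fn main() {
    let flows = vec![
        vec!["m.login.registration_token", "m.login.dummy"],
        vec!["m.login.password"],
    ];
    assert!(!any_flow_complete(&flows, &["m.login.registration_token"]));
    assert!(any_flow_complete(&flows, &["m.login.password"]));
}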
Remove this session and return true + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + None, + ); + + Ok((true, uiaainfo)) +} + +#[implement(Service)] +fn set_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue) { + let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned()); + self.userdevicesessionid_uiaarequest + .write() + .expect("locked for writing") + .insert(key, request.to_owned()); +} + +#[implement(Service)] +pub fn get_uiaa_request( + &self, user_id: &UserId, device_id: Option<&DeviceId>, session: &str, +) -> Option<CanonicalJsonValue> { + let key = ( + user_id.to_owned(), + device_id.unwrap_or_else(|| EMPTY.into()).to_owned(), + session.to_owned(), + ); + + self.userdevicesessionid_uiaarequest + .read() + .expect("locked for reading") + .get(&key) + .cloned() +} + +#[implement(Service)] +fn update_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>) { + let key = (user_id, device_id, session); + + if let Some(uiaainfo) = uiaainfo { + self.db + .userdevicesessionid_uiaainfo + .put(key, Json(uiaainfo)); + } else { + self.db.userdevicesessionid_uiaainfo.del(key); } } + +#[implement(Service)] +async fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> { + let key = (user_id, device_id, session); + self.db + .userdevicesessionid_uiaainfo + .qry(&key) + .await + .deserialized() + .map_err(|_| err!(Request(Forbidden("UIAA session does not exist.")))) +} diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs index 3c69b243002f5b6629762c111e49525c90617c19..adc85fe60a109ed1edf3d0dd79d2ab8a2c31f9f0 100644 --- a/src/service/updates/mod.rs +++ b/src/service/updates/mod.rs @@ -1,19 +1,22 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{debug, err, info, utils, warn, Error, Result}; -use database::Map; +use conduit::{debug, info, warn, Result}; +use database::{Deserialized, Map}; use ruma::events::room::message::RoomMessageEventContent; use serde::Deserialize; -use tokio::{sync::Notify, time::interval}; +use tokio::{ + sync::Notify, + time::{interval, MissedTickBehavior}, +}; use crate::{admin, client, globals, Dep}; pub struct Service { - services: Services, - db: Arc<Map>, - interrupt: Notify, interval: Duration, + interrupt: Notify, + db: Arc<Map>, + services: Services, } struct Services { @@ -22,12 +25,12 @@ struct Services { globals: Dep<globals::Service>, } -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] struct CheckForUpdatesResponse { updates: Vec<CheckForUpdatesResponseEntry>, } -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] struct CheckForUpdatesResponseEntry { id: u64, date: String, @@ -42,33 +45,38 @@ struct CheckForUpdatesResponseEntry { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), + interrupt: Notify::new(), + db: args.db["global"].clone(), services: Services { globals: args.depend::<globals::Service>("globals"), admin: args.depend::<admin::Service>("admin"), client: args.depend::<client::Service>("client"), }, - db: args.db["global"].clone(), - interrupt: Notify::new(), - interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), })) } + #[tracing::instrument(skip_all, name = "updates", level = "trace")] async fn worker(self: Arc<Self>) -> Result<()> { if 
!self.services.globals.allow_check_for_updates() { debug!("Disabling update check"); return Ok(()); } + let mut i = interval(self.interval); + i.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { tokio::select! { - () = self.interrupt.notified() => return Ok(()), + () = self.interrupt.notified() => break, _ = i.tick() => (), } - if let Err(e) = self.handle_updates().await { + if let Err(e) = self.check().await { warn!(%e, "Failed to check for updates"); } } + + Ok(()) } fn interrupt(&self) { self.interrupt.notify_waiters(); } @@ -77,52 +85,49 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - #[tracing::instrument(skip_all)] - async fn handle_updates(&self) -> Result<()> { + #[tracing::instrument(skip_all, level = "trace")] + async fn check(&self) -> Result<()> { let response = self .services .client .default .get(CHECK_FOR_UPDATES_URL) .send() + .await? + .text() .await?; - let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?) - .map_err(|e| err!("Bad check for updates response: {e}"))?; - - let mut last_update_id = self.last_check_for_updates_id()?; - for update in response.updates { - last_update_id = last_update_id.max(update.id); - if update.id > self.last_check_for_updates_id()? { - info!("{:#}", update.message); - self.services - .admin - .send_message(RoomMessageEventContent::text_markdown(format!( - "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", - update.date, update.message - ))) - .await; + let response = serde_json::from_str::<CheckForUpdatesResponse>(&response)?; + for update in &response.updates { + if update.id > self.last_check_for_updates_id().await { + self.handle(update).await; + self.update_check_for_updates_id(update.id); } } - self.update_check_for_updates_id(last_update_id)?; Ok(()) } - #[inline] - pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { - self.db - .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; - - Ok(()) + async fn handle(&self, update: &CheckForUpdatesResponseEntry) { + info!("{} {:#}", update.date, update.message); + self.services + .admin + .send_message(RoomMessageEventContent::text_markdown(format!( + "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", + update.date, update.message + ))) + .await + .ok(); } - pub fn last_check_for_updates_id(&self) -> Result<u64> { + #[inline] + pub fn update_check_for_updates_id(&self, id: u64) { self.db.raw_put(LAST_CHECK_FOR_UPDATES_COUNT, id); } + + pub async fn last_check_for_updates_id(&self) -> u64 { self.db - .get(LAST_CHECK_FOR_UPDATES_COUNT)? 
- .map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) - }) + .get(LAST_CHECK_FOR_UPDATES_COUNT) + .await + .deserialized() + .unwrap_or(0_u64) } } diff --git a/src/service/users/data.rs b/src/service/users/data.rs deleted file mode 100644 index 70ff12e3f1d45729d8981de87f158a50a3356b24..0000000000000000000000000000000000000000 --- a/src/service/users/data.rs +++ /dev/null @@ -1,1098 +0,0 @@ -use std::{collections::BTreeMap, mem::size_of, sync::Arc}; - -use conduit::{debug_info, err, utils, warn, Err, Error, Result, Server}; -use database::Map; -use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{AnyToDeviceEvent, StateEventType}, - serde::Raw, - uint, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, - OwnedMxcUri, OwnedUserId, UInt, UserId, -}; - -use crate::{globals, rooms, users::clean_signatures, Dep}; - -pub struct Data { - keychangeid_userid: Arc<Map>, - keyid_key: Arc<Map>, - onetimekeyid_onetimekeys: Arc<Map>, - openidtoken_expiresatuserid: Arc<Map>, - todeviceid_events: Arc<Map>, - token_userdeviceid: Arc<Map>, - userdeviceid_metadata: Arc<Map>, - userdeviceid_token: Arc<Map>, - userfilterid_filter: Arc<Map>, - userid_avatarurl: Arc<Map>, - userid_blurhash: Arc<Map>, - userid_devicelistversion: Arc<Map>, - userid_displayname: Arc<Map>, - userid_lastonetimekeyupdate: Arc<Map>, - userid_masterkeyid: Arc<Map>, - userid_password: Arc<Map>, - userid_selfsigningkeyid: Arc<Map>, - userid_usersigningkeyid: Arc<Map>, - useridprofilekey_value: Arc<Map>, - services: Services, -} - -struct Services { - server: Arc<Server>, - globals: Dep<globals::Service>, - state_cache: Dep<rooms::state_cache::Service>, - state_accessor: Dep<rooms::state_accessor::Service>, -} - -impl Data { - pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; - Self { - keychangeid_userid: db["keychangeid_userid"].clone(), - keyid_key: db["keyid_key"].clone(), - onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(), - openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(), - todeviceid_events: db["todeviceid_events"].clone(), - token_userdeviceid: db["token_userdeviceid"].clone(), - userdeviceid_metadata: db["userdeviceid_metadata"].clone(), - userdeviceid_token: db["userdeviceid_token"].clone(), - userfilterid_filter: db["userfilterid_filter"].clone(), - userid_avatarurl: db["userid_avatarurl"].clone(), - userid_blurhash: db["userid_blurhash"].clone(), - userid_devicelistversion: db["userid_devicelistversion"].clone(), - userid_displayname: db["userid_displayname"].clone(), - userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), - userid_masterkeyid: db["userid_masterkeyid"].clone(), - userid_password: db["userid_password"].clone(), - userid_selfsigningkeyid: db["userid_selfsigningkeyid"].clone(), - userid_usersigningkeyid: db["userid_usersigningkeyid"].clone(), - useridprofilekey_value: db["useridprofilekey_value"].clone(), - services: Services { - server: args.server.clone(), - globals: args.depend::<globals::Service>("globals"), - state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), - state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), - }, - } - } - - /// Check if a user has an account on this homeserver. 
- #[inline] - pub(super) fn exists(&self, user_id: &UserId) -> Result<bool> { - Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) - } - - /// Check if account is deactivated - pub(super) fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { - Ok(self - .userid_password - .get(user_id.as_bytes())? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))? - .is_empty()) - } - - /// Returns the number of users registered on this server. - #[inline] - pub(super) fn count(&self) -> Result<usize> { Ok(self.userid_password.iter().count()) } - - /// Find out which user an access token belongs to. - pub(super) fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> { - self.token_userdeviceid - .get(token.as_bytes())? - .map_or(Ok(None), |bytes| { - let mut parts = bytes.split(|&b| b == 0xFF); - let user_bytes = parts - .next() - .ok_or_else(|| err!(Database("User ID in token_userdeviceid is invalid.")))?; - let device_bytes = parts - .next() - .ok_or_else(|| err!(Database("Device ID in token_userdeviceid is invalid.")))?; - - Ok(Some(( - UserId::parse( - utils::string_from_bytes(user_bytes) - .map_err(|e| err!(Database("User ID in token_userdeviceid is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in token_userdeviceid is invalid. {e}")))?, - utils::string_from_bytes(device_bytes) - .map_err(|e| err!(Database("Device ID in token_userdeviceid is invalid. {e}")))?, - ))) - }) - } - - /// Returns an iterator over all users on this homeserver. - pub fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { - Box::new(self.userid_password.iter().map(|(bytes, _)| { - UserId::parse( - utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database("User ID in userid_password is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in userid_password is invalid. {e}"))) - })) - } - - /// Returns a list of local users as list of usernames. - /// - /// A user account is considered `local` if the length of it's password is - /// greater then zero. - pub(super) fn list_local_users(&self) -> Result<Vec<String>> { - let users: Vec<String> = self - .userid_password - .iter() - .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) - .collect(); - Ok(users) - } - - /// Returns the password hash for the given user. - pub(super) fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { - self.userid_password - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Password hash in db is not valid string.") - })?)) - }) - } - - /// Hash and set the user's password to the Argon2 hash - pub(super) fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - if let Some(password) = password { - if let Ok(hash) = utils::hash::password(password) { - self.userid_password - .insert(user_id.as_bytes(), hash.as_bytes())?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Password does not meet the requirements.", - )) - } - } else { - self.userid_password.insert(user_id.as_bytes(), b"")?; - Ok(()) - } - } - - /// Returns the displayname of a user on this homeserver. - pub(super) fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { - self.userid_displayname - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some( - utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database("Displayname in db is invalid. 
{e}")))?, - )) - }) - } - - /// Sets a new displayname or removes it if displayname is None. You still - /// need to nofify all rooms of this change. - pub(super) fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { - if let Some(displayname) = displayname { - self.userid_displayname - .insert(user_id.as_bytes(), displayname.as_bytes())?; - } else { - self.userid_displayname.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the `avatar_url` of a user. - pub(super) fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>> { - self.userid_avatarurl - .get(user_id.as_bytes())? - .map(|bytes| { - let s_bytes = utils::string_from_bytes(&bytes) - .map_err(|e| err!(Database(warn!("Avatar URL in db is invalid: {e}"))))?; - let mxc_uri: OwnedMxcUri = s_bytes.into(); - Ok(mxc_uri) - }) - .transpose() - } - - /// Sets a new avatar_url or removes it if avatar_url is None. - pub(super) fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> { - if let Some(avatar_url) = avatar_url { - self.userid_avatarurl - .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; - } else { - self.userid_avatarurl.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the blurhash of a user. - pub(super) fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> { - self.userid_blurhash - .get(user_id.as_bytes())? - .map(|bytes| { - utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Avatar URL in db is invalid. {e}"))) - }) - .transpose() - } - - /// Gets a specific user profile key - pub(super) fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result<Option<serde_json::Value>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(profile_key.as_bytes()); - - self.useridprofilekey_value - .get(&key)? - .map_or(Ok(None), |bytes| Ok(Some(serde_json::from_slice(&bytes).unwrap()))) - } - - /// Gets all the user's profile keys and values in an iterator - pub(super) fn all_profile_keys<'a>( - &'a self, user_id: &UserId, - ) -> Box<dyn Iterator<Item = Result<(String, serde_json::Value)>> + 'a + Send> { - let prefix = user_id.as_bytes().to_vec(); - - Box::new( - self.useridprofilekey_value - .scan_prefix(prefix) - .map(|(key, value)| { - let profile_key_name = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("Profile key in db is invalid")))?, - ) - .map_err(|e| err!(Database("Profile key in db is invalid. {e}")))?; - - let profile_key_value = serde_json::from_slice(&value) - .map_err(|e| err!(Database("Profile key in db is invalid. {e}")))?; - - Ok((profile_key_name, profile_key_value)) - }), - ) - } - - /// Sets a new profile key value, removes the key if value is None - pub(super) fn set_profile_key( - &self, user_id: &UserId, profile_key: &str, profile_key_value: Option<serde_json::Value>, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(profile_key.as_bytes()); - - // TODO: insert to the stable MSC4175 key when it's stable - if let Some(value) = profile_key_value { - let value = serde_json::to_vec(&value).unwrap(); - - self.useridprofilekey_value.insert(&key, &value) - } else { - self.useridprofilekey_value.remove(&key) - } - } - - /// Get the timezone of a user. 
- pub(super) fn timezone(&self, user_id: &UserId) -> Result<Option<String>> { - // first check the unstable prefix - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"us.cloke.msc4175.tz"); - - let value = self - .useridprofilekey_value - .get(&key)? - .map(|bytes| utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Timezone in db is invalid. {e}")))) - .transpose() - .unwrap(); - - // TODO: transparently migrate unstable key usage to the stable key once MSC4133 - // and MSC4175 are stable, likely a remove/insert in this block - if value.is_none() || value.as_ref().is_some_and(String::is_empty) { - // check the stable prefix - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"m.tz"); - - return self - .useridprofilekey_value - .get(&key)? - .map(|bytes| { - utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Timezone in db is invalid. {e}"))) - }) - .transpose(); - } - - Ok(value) - } - - /// Sets a new timezone or removes it if timezone is None. - pub(super) fn set_timezone(&self, user_id: &UserId, timezone: Option<String>) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(b"us.cloke.msc4175.tz"); - - // TODO: insert to the stable MSC4175 key when it's stable - if let Some(timezone) = timezone { - self.useridprofilekey_value - .insert(&key, timezone.as_bytes())?; - } else { - self.useridprofilekey_value.remove(&key)?; - } - - Ok(()) - } - - /// Sets a new avatar_url or removes it if avatar_url is None. - pub(super) fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> { - if let Some(blurhash) = blurhash { - self.userid_blurhash - .insert(user_id.as_bytes(), blurhash.as_bytes())?; - } else { - self.userid_blurhash.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Adds a new device to a user. - pub(super) fn create_device( - &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>, - client_ip: Option<String>, - ) -> Result<()> { - // This method should never be called for nonexistent users. We shouldn't assert - // though... - if !self.exists(user_id)? { - warn!("Called create_device for non-existent user {} in database", user_id); - return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist.")); - } - - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(&Device { - device_id: device_id.into(), - display_name: initial_device_display_name, - last_seen_ip: client_ip, - last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), - }) - .expect("Device::to_string never fails."), - )?; - - self.set_token(user_id, device_id, token)?; - - Ok(()) - } - - /// Removes a device from a user. - pub(super) fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Remove tokens - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? 
{ - self.userdeviceid_token.remove(&userdeviceid)?; - self.token_userdeviceid.remove(&old_token)?; - } - - // Remove todevice events - let mut prefix = userdeviceid.clone(); - prefix.push(0xFF); - - for (key, _) in self.todeviceid_events.scan_prefix(prefix) { - self.todeviceid_events.remove(&key)?; - } - - // TODO: Remove onetimekeys - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.remove(&userdeviceid)?; - - Ok(()) - } - - /// Returns an iterator over all device ids of this user. - pub(super) fn all_device_ids<'a>( - &'a self, user_id: &UserId, - ) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - // All devices have metadata - Box::new( - self.userdeviceid_metadata - .scan_prefix(prefix) - .map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("UserDevice ID in db is invalid.")))?, - ) - .map_err(|e| err!(Database("Device ID in userdeviceid_metadata is invalid. {e}")))? - .into()) - }), - ) - } - - /// Replaces the access token of one device. - pub(super) fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // should not be None, but we shouldn't assert either lol... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - return Err!(Database(error!( - "User {user_id:?} does not exist or device ID {device_id:?} has no metadata." - ))); - } - - // Remove old token - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.token_userdeviceid.remove(&old_token)?; - // It will be removed from userdeviceid_token by the insert later - } - - // Assign token to user device combination - self.userdeviceid_token - .insert(&userdeviceid, token.as_bytes())?; - self.token_userdeviceid - .insert(token.as_bytes(), &userdeviceid)?; - - Ok(()) - } - - pub(super) fn add_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, - one_time_key_value: &Raw<OneTimeKey>, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - - // All devices have metadata - // Only existing devices should be able to call this, but we shouldn't assert - // either... - if self.userdeviceid_metadata.get(&key)?.is_none() { - return Err!(Database(error!( - "User {user_id:?} does not exist or device ID {device_id:?} has no metadata." - ))); - } - - key.push(0xFF); - // TODO: Use DeviceKeyId::to_string when it's available (and update everything, - // because there are no wrapping quotation marks anymore) - key.extend_from_slice( - serde_json::to_string(one_time_key_key) - .expect("DeviceKeyId::to_string always works") - .as_bytes(), - ); - - self.onetimekeyid_onetimekeys.insert( - &key, - &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), - )?; - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; - - Ok(()) - } - - pub(super) fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> { - self.userid_lastonetimekeyupdate - .get(user_id.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|e| err!(Database("Count in roomid_lastroomactiveupdate is invalid. 
{e}"))) - }) - } - - pub(super) fn take_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - ) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.push(b'"'); // Annoying quotation mark - prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); - prefix.push(b':'); - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; - - self.onetimekeyid_onetimekeys - .scan_prefix(prefix) - .next() - .map(|(key, value)| { - self.onetimekeyid_onetimekeys.remove(&key)?; - - Ok(( - serde_json::from_slice( - key.rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))?, - ) - .map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))?, - serde_json::from_slice(&value).map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))?, - )) - }) - .transpose() - } - - pub(super) fn count_one_time_keys( - &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - let mut counts = BTreeMap::new(); - - for algorithm in self - .onetimekeyid_onetimekeys - .scan_prefix(userdeviceid) - .map(|(bytes, _)| { - Ok::<_, Error>( - serde_json::from_slice::<OwnedDeviceKeyId>( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("OneTimeKey ID in db is invalid.")))?, - ) - .map_err(|e| err!(Database("DeviceKeyId in db is invalid. {e}")))? - .algorithm(), - ) - }) { - let count: &mut UInt = counts.entry(algorithm?).or_default(); - *count = count.saturating_add(uint!(1)); - } - - Ok(counts) - } - - pub(super) fn add_device_keys( - &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>, - ) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.insert( - &userdeviceid, - &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), - )?; - - self.mark_device_key_update(user_id)?; - - Ok(()) - } - - pub(super) fn add_cross_signing_keys( - &self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>, - user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool, - ) -> Result<()> { - // TODO: Check signatures - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let (master_key_key, _) = Self::parse_master_key(user_id, master_key)?; - - self.keyid_key - .insert(&master_key_key, master_key.json().get().as_bytes())?; - - self.userid_masterkeyid - .insert(user_id.as_bytes(), &master_key_key)?; - - // Self-signing key - if let Some(self_signing_key) = self_signing_key { - let mut self_signing_key_ids = self_signing_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))? 
- .keys - .into_values(); - - let self_signing_key_id = self_signing_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?; - - if self_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained more than one key.", - )); - } - - let mut self_signing_key_key = prefix.clone(); - self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); - - self.keyid_key - .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?; - - self.userid_selfsigningkeyid - .insert(user_id.as_bytes(), &self_signing_key_key)?; - } - - // User-signing key - if let Some(user_signing_key) = user_signing_key { - let mut user_signing_key_ids = user_signing_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))? - .keys - .into_values(); - - let user_signing_key_id = user_signing_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?; - - if user_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained more than one key.", - )); - } - - let mut user_signing_key_key = prefix; - user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); - - self.keyid_key - .insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?; - - self.userid_usersigningkeyid - .insert(user_id.as_bytes(), &user_signing_key_key)?; - } - - if notify { - self.mark_device_key_update(user_id)?; - } - - Ok(()) - } - - pub(super) fn sign_key( - &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, - ) -> Result<()> { - let mut key = target_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(key_id.as_bytes()); - - let mut cross_signing_key: serde_json::Value = serde_json::from_slice( - &self - .keyid_key - .get(&key)? - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?, - ) - .map_err(|e| err!(Database("key in keyid_key is invalid. {e}")))?; - - let signatures = cross_signing_key - .get_mut("signatures") - .ok_or_else(|| err!(Database("key in keyid_key has no signatures field.")))? - .as_object_mut() - .ok_or_else(|| err!(Database("key in keyid_key has invalid signatures field.")))? - .entry(sender_id.to_string()) - .or_insert_with(|| serde_json::Map::new().into()); - - signatures - .as_object_mut() - .ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))? 
- .insert(signature.0, signature.1.into()); - - self.keyid_key.insert( - &key, - &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), - )?; - - self.mark_device_key_update(target_id)?; - - Ok(()) - } - - pub(super) fn keys_changed<'a>( - &'a self, user_or_room_id: &str, from: u64, to: Option<u64>, - ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { - let mut prefix = user_or_room_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let mut start = prefix.clone(); - start.extend_from_slice(&(from.saturating_add(1)).to_be_bytes()); - - let to = to.unwrap_or(u64::MAX); - - Box::new( - self.keychangeid_userid - .iter_from(&start, false) - .take_while(move |(k, _)| { - k.starts_with(&prefix) - && if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) { - if let Ok(c) = utils::u64_from_bytes(current) { - c <= to - } else { - warn!("BadDatabase: Could not parse keychangeid_userid bytes"); - false - } - } else { - warn!("BadDatabase: Could not parse keychangeid_userid"); - false - } - }) - .map(|(_, bytes)| { - UserId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") - })?, - ) - .map_err(|e| err!(Database("User ID in devicekeychangeid_userid is invalid. {e}"))) - }), - ) - } - - pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - let count = self.services.globals.next_count()?.to_be_bytes(); - for room_id in self - .services - .state_cache - .rooms_joined(user_id) - .filter_map(Result::ok) - { - // Don't send key updates to unencrypted rooms - if self - .services - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? - .is_none() - { - continue; - } - - let mut key = room_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - } - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(&count); - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - - Ok(()) - } - - pub(super) fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes).map_err(|e| err!(Database("DeviceKeys in db are invalid. 
{e}")))?, - )) - }) - } - - pub(super) fn parse_master_key( - user_id: &UserId, master_key: &Raw<CrossSigningKey>, - ) -> Result<(Vec<u8>, CrossSigningKey)> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let master_key = master_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; - let mut master_key_ids = master_key.keys.values(); - let master_key_id = master_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; - if master_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained more than one key.", - )); - } - let mut master_key_key = prefix.clone(); - master_key_key.extend_from_slice(master_key_id.as_bytes()); - Ok((master_key_key, master_key)) - } - - pub(super) fn get_key( - &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result<Option<Raw<CrossSigningKey>>> { - self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes) - .map_err(|e| err!(Database("CrossSigningKey in db is invalid. {e}")))?; - clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?; - - Ok(Some(Raw::from_json( - serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"), - ))) - }) - } - - pub(super) fn get_master_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result<Option<Raw<CrossSigningKey>>> { - self.userid_masterkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) - } - - pub(super) fn get_self_signing_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result<Option<Raw<CrossSigningKey>>> { - self.userid_selfsigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) - } - - pub(super) fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> { - self.userid_usersigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes) - .map_err(|e| err!(Database("CrossSigningKey in db is invalid. 
{e}")))?, - )) - }) - }) - } - - pub(super) fn add_to_device_event( - &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, - content: serde_json::Value, - ) -> Result<()> { - let mut key = target_user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(target_device_id.as_bytes()); - key.push(0xFF); - key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); - - let mut json = serde_json::Map::new(); - json.insert("type".to_owned(), event_type.to_owned().into()); - json.insert("sender".to_owned(), sender.to_string().into()); - json.insert("content".to_owned(), content); - - let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - - self.todeviceid_events.insert(&key, &value)?; - - Ok(()) - } - - pub(super) fn get_to_device_events( - &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result<Vec<Raw<AnyToDeviceEvent>>> { - let mut events = Vec::new(); - - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - - for (_, value) in self.todeviceid_events.scan_prefix(prefix) { - events.push( - serde_json::from_slice(&value) - .map_err(|e| err!(Database("Event in todeviceid_events is invalid. {e}")))?, - ); - } - - Ok(events) - } - - pub(super) fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - - let mut last = prefix.clone(); - last.extend_from_slice(&until.to_be_bytes()); - - for (key, _) in self - .todeviceid_events - .iter_from(&last, true) // this includes last - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(key, _)| { - Ok::<_, Error>(( - key.clone(), - utils::u64_from_bytes(&key[key.len().saturating_sub(size_of::<u64>())..key.len()]) - .map_err(|e| err!(Database("ToDeviceId has invalid count bytes. {e}")))?, - )) - }) - .filter_map(Result::ok) - .take_while(|&(_, count)| count <= until) - { - self.todeviceid_events.remove(&key)?; - } - - Ok(()) - } - - pub(super) fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Only existing devices should be able to call this, but we shouldn't assert - // either... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - warn!( - "Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \ - metadata in database", - user_id, device_id - ); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); - } - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(device).expect("Device::to_string always works"), - )?; - - Ok(()) - } - - /// Get device metadata. - pub(super) fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xFF); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userdeviceid_metadata - .get(&userdeviceid)? 
- .map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("Metadata in userdeviceid_metadata is invalid.") - })?)) - }) - } - - pub(super) fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> { - self.userid_devicelistversion - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|e| err!(Database("Invalid devicelistversion in db. {e}"))) - .map(Some) - }) - } - - pub(super) fn all_devices_metadata<'a>( - &'a self, user_id: &UserId, - ) -> Box<dyn Iterator<Item = Result<Device>> + 'a> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - - Box::new( - self.userdeviceid_metadata - .scan_prefix(key) - .map(|(_, bytes)| { - serde_json::from_slice::<Device>(&bytes) - .map_err(|e| err!(Database("Device in userdeviceid_metadata is invalid. {e}"))) - }), - ) - } - - /// Creates a new sync filter. Returns the filter id. - pub(super) fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> { - let filter_id = utils::random_string(4); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - self.userfilterid_filter - .insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?; - - Ok(filter_id) - } - - pub(super) fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(filter_id.as_bytes()); - - let raw = self.userfilterid_filter.get(&key)?; - - if let Some(raw) = raw { - serde_json::from_slice(&raw).map_err(|e| err!(Database("Invalid filter event in db. {e}"))) - } else { - Ok(None) - } - } - - /// Creates an OpenID token, which can be used to prove that a user has - /// access to an account (primarily for integrations) - pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> { - use std::num::Saturating as Sat; - - let expires_in = self.services.server.config.openid_token_ttl; - let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); - - let mut value = expires_at.0.to_be_bytes().to_vec(); - value.extend_from_slice(user_id.as_bytes()); - - self.openidtoken_expiresatuserid - .insert(token.as_bytes(), value.as_slice())?; - - Ok(expires_in) - } - - /// Find out which user an OpenID access token belongs to. - pub(super) fn find_from_openid_token(&self, token: &str) -> Result<OwnedUserId> { - let Some(value) = self.openidtoken_expiresatuserid.get(token.as_bytes())? else { - return Err(Error::BadRequest(ErrorKind::Unauthorized, "OpenID token is unrecognised")); - }; - - let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len()); - - let expires_at = u64::from_be_bytes( - expires_at_bytes - .try_into() - .map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?, - ); - - if expires_at < utils::millis_since_unix_epoch() { - debug_info!("OpenID token is expired, removing"); - self.openidtoken_expiresatuserid.remove(token.as_bytes())?; - - return Err(Error::BadRequest(ErrorKind::Unauthorized, "OpenID token is expired")); - } - - UserId::parse( - utils::string_from_bytes(user_bytes) - .map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?, - ) - .map_err(|e| err!(Database("User ID in openid_userid is invalid. 
{e}"))) - } -} - -/// Will only return with Some(username) if the password was not empty and the -/// username could be successfully parsed. -/// If `utils::string_from_bytes`(...) returns an error that username will be -/// skipped and the error will be logged. -pub(super) fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> { - // A valid password is not empty - if password.is_empty() { - None - } else { - match utils::string_from_bytes(username) { - Ok(u) => Some(u), - Err(e) => { - warn!("Failed to parse username while calling get_local_users(): {}", e.to_string()); - None - }, - } - } -} diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 80897b5ff3a913bb6f83b65b36971ac32e1079e7..1f8c56dfaff3698a38a2a910f0aa7944564a8d88 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,552 +1,931 @@ -mod data; +use std::{collections::BTreeMap, mem, mem::size_of, sync::Arc}; -use std::{ - collections::{BTreeMap, BTreeSet}, - mem, - sync::{Arc, Mutex, Mutex as StdMutex}, +use conduit::{ + debug_warn, err, utils, + utils::{stream::TryIgnore, string::Unquoted, ReadyExt}, + Err, Error, Result, Server, }; - -use conduit::{Error, Result}; +use database::{Deserialized, Ignore, Interfix, Json, Map}; +use futures::{FutureExt, Stream, StreamExt, TryFutureExt}; use ruma::{ - api::client::{ - device::Device, - filter::FilterDefinition, - sync::sync_events::{ - self, - v4::{ExtensionsConfig, SyncRequestList}, - }, - }, + api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::AnyToDeviceEvent, + events::{ignored_user_list::IgnoredUserListEvent, AnyToDeviceEvent, GlobalAccountDataEventType}, serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedRoomId, OwnedUserId, - UInt, UserId, + DeviceId, KeyId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, OwnedDeviceId, + OwnedKeyId, OwnedMxcUri, OwnedUserId, RoomId, UInt, UserId, }; +use serde_json::json; -use self::data::Data; -use crate::{admin, rooms, Dep}; +use crate::{account_data, admin, globals, rooms, Dep}; pub struct Service { - connections: DbConnections, - pub db: Data, services: Services, + db: Data, } struct Services { + server: Arc<Server>, + account_data: Dep<account_data::Service>, admin: Dep<admin::Service>, + globals: Dep<globals::Service>, + state_accessor: Dep<rooms::state_accessor::Service>, state_cache: Dep<rooms::state_cache::Service>, } +struct Data { + keychangeid_userid: Arc<Map>, + keyid_key: Arc<Map>, + onetimekeyid_onetimekeys: Arc<Map>, + openidtoken_expiresatuserid: Arc<Map>, + todeviceid_events: Arc<Map>, + token_userdeviceid: Arc<Map>, + userdeviceid_metadata: Arc<Map>, + userdeviceid_token: Arc<Map>, + userfilterid_filter: Arc<Map>, + userid_avatarurl: Arc<Map>, + userid_blurhash: Arc<Map>, + userid_devicelistversion: Arc<Map>, + userid_displayname: Arc<Map>, + userid_lastonetimekeyupdate: Arc<Map>, + userid_masterkeyid: Arc<Map>, + userid_password: Arc<Map>, + userid_selfsigningkeyid: Arc<Map>, + userid_usersigningkeyid: Arc<Map>, + useridprofilekey_value: Arc<Map>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - connections: StdMutex::new(BTreeMap::new()), - db: Data::new(&args), services: Services { + server: args.server.clone(), + account_data: args.depend::<account_data::Service>("account_data"), admin: 
args.depend::<admin::Service>("admin"), + globals: args.depend::<globals::Service>("globals"), + state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), }, + db: Data { + keychangeid_userid: args.db["keychangeid_userid"].clone(), + keyid_key: args.db["keyid_key"].clone(), + onetimekeyid_onetimekeys: args.db["onetimekeyid_onetimekeys"].clone(), + openidtoken_expiresatuserid: args.db["openidtoken_expiresatuserid"].clone(), + todeviceid_events: args.db["todeviceid_events"].clone(), + token_userdeviceid: args.db["token_userdeviceid"].clone(), + userdeviceid_metadata: args.db["userdeviceid_metadata"].clone(), + userdeviceid_token: args.db["userdeviceid_token"].clone(), + userfilterid_filter: args.db["userfilterid_filter"].clone(), + userid_avatarurl: args.db["userid_avatarurl"].clone(), + userid_blurhash: args.db["userid_blurhash"].clone(), + userid_devicelistversion: args.db["userid_devicelistversion"].clone(), + userid_displayname: args.db["userid_displayname"].clone(), + userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(), + userid_masterkeyid: args.db["userid_masterkeyid"].clone(), + userid_password: args.db["userid_password"].clone(), + userid_selfsigningkeyid: args.db["userid_selfsigningkeyid"].clone(), + userid_usersigningkeyid: args.db["userid_usersigningkeyid"].clone(), + useridprofilekey_value: args.db["useridprofilekey_value"].clone(), + }, })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -type DbConnections = Mutex<BTreeMap<DbConnectionsKey, DbConnectionsVal>>; -type DbConnectionsKey = (OwnedUserId, OwnedDeviceId, String); -type DbConnectionsVal = Arc<Mutex<SlidingSyncCache>>; - -struct SlidingSyncCache { - lists: BTreeMap<String, SyncRequestList>, - subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>, - known_rooms: BTreeMap<String, BTreeMap<OwnedRoomId, u64>>, // For every room, the roomsince number - extensions: ExtensionsConfig, -} - impl Service { - /// Check if a user has an account on this homeserver. 
- #[inline] - pub fn exists(&self, user_id: &UserId) -> Result<bool> { self.db.exists(user_id) } - - pub fn remembered(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) -> bool { - self.connections - .lock() - .unwrap() - .contains_key(&(user_id, device_id, conn_id)) - } - - pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { - self.connections - .lock() - .unwrap() - .remove(&(user_id, device_id, conn_id)); - } - - pub fn update_sync_request_with_cache( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request, - ) -> BTreeMap<String, BTreeMap<OwnedRoomId, u64>> { - let Some(conn_id) = request.conn_id.clone() else { - return BTreeMap::new(); - }; - - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); - - for (list_id, list) in &mut request.lists { - if let Some(cached_list) = cached.lists.get(list_id) { - if list.sort.is_empty() { - list.sort.clone_from(&cached_list.sort); - }; - if list.room_details.required_state.is_empty() { - list.room_details - .required_state - .clone_from(&cached_list.room_details.required_state); - }; - list.room_details.timeline_limit = list - .room_details - .timeline_limit - .or(cached_list.room_details.timeline_limit); - list.include_old_rooms = list - .include_old_rooms - .clone() - .or_else(|| cached_list.include_old_rooms.clone()); - match (&mut list.filters, cached_list.filters.clone()) { - (Some(list_filters), Some(cached_filters)) => { - list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); - if list_filters.spaces.is_empty() { - list_filters.spaces = cached_filters.spaces; - } - list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted); - list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite); - if list_filters.room_types.is_empty() { - list_filters.room_types = cached_filters.room_types; - } - if list_filters.not_room_types.is_empty() { - list_filters.not_room_types = cached_filters.not_room_types; - } - list_filters.room_name_like = list_filters - .room_name_like - .clone() - .or(cached_filters.room_name_like); - if list_filters.tags.is_empty() { - list_filters.tags = cached_filters.tags; - } - if list_filters.not_tags.is_empty() { - list_filters.not_tags = cached_filters.not_tags; - } - }, - (_, Some(cached_filters)) => list.filters = Some(cached_filters), - (Some(list_filters), _) => list.filters = Some(list_filters.clone()), - (..) 
=> {}, - } - if list.bump_event_types.is_empty() { - list.bump_event_types - .clone_from(&cached_list.bump_event_types); - }; - } - cached.lists.insert(list_id.clone(), list.clone()); - } - - cached - .subscriptions - .extend(request.room_subscriptions.clone()); - request - .room_subscriptions - .extend(cached.subscriptions.clone()); - - request.extensions.e2ee.enabled = request - .extensions - .e2ee - .enabled - .or(cached.extensions.e2ee.enabled); - - request.extensions.to_device.enabled = request - .extensions - .to_device - .enabled - .or(cached.extensions.to_device.enabled); - - request.extensions.account_data.enabled = request - .extensions - .account_data - .enabled - .or(cached.extensions.account_data.enabled); - request.extensions.account_data.lists = request - .extensions + /// Returns true/false based on whether the recipient/receiving user has + /// blocked the sender + pub async fn user_is_ignored(&self, sender_user: &UserId, recipient_user: &UserId) -> bool { + self.services .account_data - .lists - .clone() - .or_else(|| cached.extensions.account_data.lists.clone()); - request.extensions.account_data.rooms = request - .extensions - .account_data - .rooms - .clone() - .or_else(|| cached.extensions.account_data.rooms.clone()); + .get_global(recipient_user, GlobalAccountDataEventType::IgnoredUserList) + .await + .map_or(false, |ignored: IgnoredUserListEvent| { + ignored + .content + .ignored_users + .keys() + .any(|blocked_user| blocked_user == sender_user) + }) + } - cached.extensions = request.extensions.clone(); + /// Check if a user is an admin + #[inline] + pub async fn is_admin(&self, user_id: &UserId) -> bool { self.services.admin.user_is_admin(user_id).await } - cached.known_rooms.clone() + /// Create a new user account on this homeserver. + #[inline] + pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + self.set_password(user_id, password) } - pub fn update_sync_subscriptions( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, - subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>, - ) { - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); + /// Deactivate account + pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + // Remove all associated devices + self.all_device_ids(user_id) + .for_each(|device_id| self.remove_device(user_id, device_id)) + .await; - cached.subscriptions = subscriptions; - } + // Set the password to "" to indicate a deactivated account. Hashes will never + // result in an empty string, so the user will not be able to log in again. + // Systems like changing the password without logging in should check if the + // account is deactivated. 
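+		// An empty stored value is exactly what `is_deactivated` below checks for,
+		// e.g. (hypothetical caller sketch) `users.deactivate_account(&user).await?`
+		// followed by `users.is_deactivated(&user).await` yields `Ok(true)`.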
+ self.set_password(user_id, None)?; - pub fn update_sync_known_rooms( - &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String, - new_cached_rooms: BTreeSet<OwnedRoomId>, globalsince: u64, - ) { - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); - - for (roomid, lastsince) in cached - .known_rooms - .entry(list_id.clone()) - .or_default() - .iter_mut() - { - if !new_cached_rooms.contains(roomid) { - *lastsince = 0; - } - } - let list = cached.known_rooms.entry(list_id).or_default(); - for roomid in new_cached_rooms { - list.insert(roomid, globalsince); - } + // TODO: Unhook 3PID + Ok(()) } - /// Check if account is deactivated - pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { self.db.is_deactivated(user_id) } + /// Check if a user has an account on this homeserver. + #[inline] + pub async fn exists(&self, user_id: &UserId) -> bool { self.db.userid_password.get(user_id).await.is_ok() } - /// Check if a user is an admin - pub fn is_admin(&self, user_id: &UserId) -> Result<bool> { - if let Some(admin_room_id) = self.services.admin.get_admin_room()? { - self.services.state_cache.is_joined(user_id, &admin_room_id) - } else { - Ok(false) - } + /// Check if account is deactivated + pub async fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { + self.db + .userid_password + .get(user_id) + .map_ok(|val| val.is_empty()) + .map_err(|_| err!(Request(NotFound("User does not exist.")))) + .await } - /// Create a new user account on this homeserver. - #[inline] - pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.db.set_password(user_id, password)?; - Ok(()) + /// Check if account is active, infallible + pub async fn is_active(&self, user_id: &UserId) -> bool { !self.is_deactivated(user_id).await.unwrap_or(true) } + + /// Check if account is active, infallible + pub async fn is_active_local(&self, user_id: &UserId) -> bool { + self.services.globals.user_is_local(user_id) && self.is_active(user_id).await } /// Returns the number of users registered on this server. #[inline] - pub fn count(&self) -> Result<usize> { self.db.count() } + pub async fn count(&self) -> usize { self.db.userid_password.count().await } /// Find out which user an access token belongs to. - pub fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> { - self.db.find_from_token(token) + pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> { + self.db.token_userdeviceid.get(token).await.deserialized() } + /// Returns an iterator over all users on this homeserver (offered for + /// compatibility) + #[allow(clippy::iter_without_into_iter, clippy::iter_not_returning_iterator)] + pub fn iter(&self) -> impl Stream<Item = OwnedUserId> + Send + '_ { self.stream().map(ToOwned::to_owned) } + /// Returns an iterator over all users on this homeserver. - pub fn iter(&self) -> impl Iterator<Item = Result<OwnedUserId>> + '_ { self.db.iter() } + pub fn stream(&self) -> impl Stream<Item = &UserId> + Send { self.db.userid_password.keys().ignore_err() } /// Returns a list of local users as list of usernames. 
	///
	/// A user account is considered `local` if the length of its password is
	/// greater than zero.
-	pub fn list_local_users(&self) -> Result<Vec<String>> { self.db.list_local_users() }
+	pub fn list_local_users(&self) -> impl Stream<Item = &UserId> + Send + '_ {
+		self.db
+			.userid_password
+			.stream()
+			.ignore_err()
+			.ready_filter_map(|(u, p): (&UserId, &[u8])| (!p.is_empty()).then_some(u))
+	}

	/// Returns the password hash for the given user.
-	pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { self.db.password_hash(user_id) }
+	pub async fn password_hash(&self, user_id: &UserId) -> Result<String> {
+		self.db.userid_password.get(user_id).await.deserialized()
+	}

	/// Hash and set the user's password to the Argon2 hash
-	#[inline]
	pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
-		self.db.set_password(user_id, password)
+		password
+			.map(utils::hash::password)
+			.transpose()
+			.map_err(|e| err!(Request(InvalidParam("Password does not meet the requirements: {e}"))))?
+			.map_or_else(
+				// an empty stored value doubles as the deactivation marker (see deactivate_account)
+				|| self.db.userid_password.insert(user_id, b""),
+				|hash| self.db.userid_password.insert(user_id, hash),
+			);
+
+		Ok(())
	}

	/// Returns the displayname of a user on this homeserver.
-	pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { self.db.displayname(user_id) }
+	pub async fn displayname(&self, user_id: &UserId) -> Result<String> {
+		self.db.userid_displayname.get(user_id).await.deserialized()
+	}

	/// Sets a new displayname or removes it if displayname is None. You still
	/// need to notify all rooms of this change.
-	pub async fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
-		self.db.set_displayname(user_id, displayname)
+	pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) {
+		if let Some(displayname) = displayname {
+			self.db.userid_displayname.insert(user_id, displayname);
+		} else {
+			self.db.userid_displayname.remove(user_id);
+		}
	}

-	/// Get the avatar_url of a user.
-	pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>> { self.db.avatar_url(user_id) }
+	/// Get the `avatar_url` of a user.
+	pub async fn avatar_url(&self, user_id: &UserId) -> Result<OwnedMxcUri> {
+		self.db.userid_avatarurl.get(user_id).await.deserialized()
+	}

	/// Sets a new avatar_url or removes it if avatar_url is None.
-	pub async fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
-		self.db.set_avatar_url(user_id, avatar_url)
+	pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) {
+		if let Some(avatar_url) = avatar_url {
+			self.db.userid_avatarurl.insert(user_id, &avatar_url);
+		} else {
+			self.db.userid_avatarurl.remove(user_id);
+		}
	}

	/// Get the blurhash of a user.
-	pub fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> { self.db.blurhash(user_id) }
-
-	pub fn timezone(&self, user_id: &UserId) -> Result<Option<String>> { self.db.timezone(user_id) }
-
-	/// Gets a specific user profile key
-	pub fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result<Option<serde_json::Value>> {
-		self.db.profile_key(user_id, profile_key)
-	}
-
-	/// Gets all the user's profile keys and values in an iterator
-	pub fn all_profile_keys<'a>(
-		&'a self, user_id: &UserId,
-	) -> Box<dyn Iterator<Item = Result<(String, serde_json::Value)>> + 'a + Send> {
-		self.db.all_profile_keys(user_id)
-	}
-
-	/// Sets a new profile key value, removes the key if value is None
-	pub fn set_profile_key(
-		&self, user_id: &UserId, profile_key: &str, profile_key_value: Option<serde_json::Value>,
-	) -> Result<()> {
-		self.db
-			.set_profile_key(user_id, profile_key, profile_key_value)
+	pub async fn blurhash(&self, user_id: &UserId) -> Result<String> {
+		self.db.userid_blurhash.get(user_id).await.deserialized()
	}

-	/// Sets a new tz or removes it if tz is None.
-	pub async fn set_timezone(&self, user_id: &UserId, tz: Option<String>) -> Result<()> {
-		self.db.set_timezone(user_id, tz)
-	}
-
-	/// Sets a new blurhash or removes it if blurhash is None.
-	pub async fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
-		self.db.set_blurhash(user_id, blurhash)
+	/// Sets a new blurhash or removes it if blurhash is None.
+	pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) {
+		if let Some(blurhash) = blurhash {
+			self.db.userid_blurhash.insert(user_id, blurhash);
+		} else {
+			self.db.userid_blurhash.remove(user_id);
+		}
	}

	/// Adds a new device to a user.
-	pub fn create_device(
+	pub async fn create_device(
		&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
		client_ip: Option<String>,
	) -> Result<()> {
-		self.db
-			.create_device(user_id, device_id, token, initial_device_display_name, client_ip)
+		// This method should never be called for nonexistent users. We shouldn't assert
+		// though...
+		if !self.exists(user_id).await {
+			return Err!(Request(InvalidParam(error!("Called create_device for non-existent {user_id}"))));
+		}
+
+		let key = (user_id, device_id);
+		let val = Device {
+			device_id: device_id.into(),
+			display_name: initial_device_display_name,
+			last_seen_ip: client_ip,
+			last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
+		};
+
+		increment(&self.db.userid_devicelistversion, user_id.as_bytes());
+		self.db.userdeviceid_metadata.put(key, Json(val));
+		self.set_token(user_id, device_id, token).await
	}

	/// Removes a device from a user.
 	/// Removes a device from a user.
-	pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
-		self.db.remove_device(user_id, device_id)
+	pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
+		let userdeviceid = (user_id, device_id);
+
+		// Remove tokens
+		if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
+			self.db.userdeviceid_token.del(userdeviceid);
+			self.db.token_userdeviceid.remove(&old_token);
+		}
+
+		// Remove todevice events
+		let prefix = (user_id, device_id, Interfix);
+		self.db
+			.todeviceid_events
+			.keys_prefix_raw(&prefix)
+			.ignore_err()
+			.ready_for_each(|key| self.db.todeviceid_events.remove(key))
+			.await;
+
+		// TODO: Remove onetimekeys
+
+		increment(&self.db.userid_devicelistversion, user_id.as_bytes());
+
+		self.db.userdeviceid_metadata.del(userdeviceid);
 	}

 	/// Returns an iterator over all device ids of this user.
-	pub fn all_device_ids<'a>(&'a self, user_id: &UserId) -> impl Iterator<Item = Result<OwnedDeviceId>> + 'a {
-		self.db.all_device_ids(user_id)
+	pub fn all_device_ids<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = &DeviceId> + Send + 'a {
+		let prefix = (user_id, Interfix);
+		self.db
+			.userdeviceid_metadata
+			.keys_prefix(&prefix)
+			.ignore_err()
+			.map(|(_, device_id): (Ignore, &DeviceId)| device_id)
+	}
+
+	pub async fn get_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
+		let key = (user_id, device_id);
+		self.db.userdeviceid_token.qry(&key).await.deserialized()
 	}

 	/// Replaces the access token of one device.
-	#[inline]
-	pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
-		self.db.set_token(user_id, device_id, token)
+	pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
+		let key = (user_id, device_id);
+		// should not be None, but we shouldn't assert either lol...
+		if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
+			return Err!(Database(error!(
+				?user_id,
+				?device_id,
+				"User does not exist or device has no metadata."
+			)));
+		}
+
+		// Remove old token
+		if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await {
+			self.db.token_userdeviceid.remove(&old_token);
+			// It will be removed from userdeviceid_token by the insert later
+		}
+
+		// Assign token to user device combination
+		self.db.userdeviceid_token.put_raw(key, token);
+		self.db.token_userdeviceid.raw_put(token, key);
+
+		Ok(())
 	}

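`userdeviceid_token` and `token_userdeviceid` are maintained as a bidirectional pair here, which is what lets request authentication resolve an access token with a single point lookup. A sketch of the reverse direction this enables (hypothetical helper, not part of this diff; assumes the same `get`/`deserialized` map API used above):

	// Hypothetical reverse lookup over token_userdeviceid:
	// access token -> (user id, device id).
	pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> {
		self.db.token_userdeviceid.get(token).await.deserialized()
	}
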
-	pub fn add_one_time_key(
-		&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
+	pub async fn add_one_time_key(
+		&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
 		one_time_key_value: &Raw<OneTimeKey>,
-	) -> Result<()> {
+	) -> Result {
+		// All devices have metadata
+		// Only existing devices should be able to call this, but we shouldn't assert
+		// either...
+		let key = (user_id, device_id);
+		if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
+			return Err!(Database(error!(
+				?user_id,
+				?device_id,
+				"User does not exist or device has no metadata."
+			)));
+		}
+
+		let mut key = user_id.as_bytes().to_vec();
+		key.push(0xFF);
+		key.extend_from_slice(device_id.as_bytes());
+		key.push(0xFF);
+		// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
+		// because there are no wrapping quotation marks anymore)
+		key.extend_from_slice(
+			serde_json::to_string(one_time_key_key)
+				.expect("DeviceKeyId::to_string always works")
+				.as_bytes(),
+		);
+
 		self.db
-			.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value)
-	}
-
-	// TODO: use this ?
-	#[allow(dead_code)]
-	pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
-		self.db.last_one_time_keys_update(user_id)
-	}
+			.onetimekeyid_onetimekeys
+			.raw_put(key, Json(one_time_key_value));
+
+		let count = self.services.globals.next_count().unwrap();
+		self.db.userid_lastonetimekeyupdate.raw_put(user_id, count);

-	pub fn take_one_time_key(
-		&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
-	) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
-		self.db.take_one_time_key(user_id, device_id, key_algorithm)
+		Ok(())
 	}

-	pub fn count_one_time_keys(
+	pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 {
+		self.db
+			.userid_lastonetimekeyupdate
+			.get(user_id)
+			.await
+			.deserialized()
+			.unwrap_or(0)
+	}
+
+	pub async fn take_one_time_key(
+		&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &OneTimeKeyAlgorithm,
+	) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
+		let count = self.services.globals.next_count()?.to_be_bytes();
+		self.db.userid_lastonetimekeyupdate.insert(user_id, count);
+
+		let mut prefix = user_id.as_bytes().to_vec();
+		prefix.push(0xFF);
+		prefix.extend_from_slice(device_id.as_bytes());
+		prefix.push(0xFF);
+		prefix.push(b'"'); // Annoying quotation mark
+		prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
+		prefix.push(b':');
+
+		let one_time_key = self
+			.db
+			.onetimekeyid_onetimekeys
+			.raw_stream_prefix(&prefix)
+			.ignore_err()
+			.map(|(key, val)| {
+				self.db.onetimekeyid_onetimekeys.remove(key);
+
+				let key = key
+					.rsplit(|&b| b == 0xFF)
+					.next()
+					.ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))
+					.unwrap();
+
+				let key = serde_json::from_slice(key)
+					.map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))
+					.unwrap();
+
+				let val = serde_json::from_slice(val)
+					.map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))
+					.unwrap();
+
+				(key, val)
+			})
+			.next()
+			.await;
+
+		one_time_key.ok_or_else(|| err!(Request(NotFound("No one-time-key found"))))
+	}
+
+	pub async fn count_one_time_keys(
 		&self, user_id: &UserId, device_id: &DeviceId,
-	) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
-		self.db.count_one_time_keys(user_id, device_id)
+	) -> BTreeMap<OneTimeKeyAlgorithm, UInt> {
+		type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore);
+
+		let mut algorithm_counts = BTreeMap::<OneTimeKeyAlgorithm, _>::new();
+		let query = (user_id, device_id);
+		self.db
+			.onetimekeyid_onetimekeys
+			.stream_prefix(&query)
+			.ignore_err()
+			.ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| {
+				let one_time_key_id: &OneTimeKeyId = device_key_id
+					.as_str()
+					.try_into()
+					.expect("Invalid DeviceKeyID in database");
+
+				let count: &mut UInt = algorithm_counts
+					.entry(one_time_key_id.algorithm())
+					.or_default();
+
+				*count = count.saturating_add(1_u32.into());
+			})
+			.await;
+
+		algorithm_counts
 	}

-	pub fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) -> Result<()> {
-		self.db.add_device_keys(user_id, device_id, device_keys)
+	pub async fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) {
+		let key = (user_id, device_id);
+
+		self.db.keyid_key.put(key, Json(device_keys));
+		self.mark_device_key_update(user_id).await;
 	}

-	pub fn add_cross_signing_keys(
+	pub async fn add_cross_signing_keys(
 		&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
 		user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
 	) -> Result<()> {
+		// TODO: Check signatures
+		let mut prefix = user_id.as_bytes().to_vec();
+		prefix.push(0xFF);
+
+		let (master_key_key, _) = parse_master_key(user_id, master_key)?;
+
 		self.db
-			.add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key, notify)
+			.keyid_key
+			.insert(&master_key_key, master_key.json().get().as_bytes());
+
+		self.db
+			.userid_masterkeyid
+			.insert(user_id.as_bytes(), &master_key_key);
+
+		// Self-signing key
+		if let Some(self_signing_key) = self_signing_key {
+			let mut self_signing_key_ids = self_signing_key
+				.deserialize()
+				.map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))?
+				.keys
+				.into_values();
+
+			let self_signing_key_id = self_signing_key_ids
+				.next()
+				.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
+
+			if self_signing_key_ids.next().is_some() {
+				return Err(Error::BadRequest(
+					ErrorKind::InvalidParam,
+					"Self signing key contained more than one key.",
+				));
+			}
+
+			let mut self_signing_key_key = prefix.clone();
+			self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
+
+			self.db
+				.keyid_key
+				.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes());
+
+			self.db
+				.userid_selfsigningkeyid
+				.insert(user_id.as_bytes(), &self_signing_key_key);
+		}
+
+		// User-signing key
+		if let Some(user_signing_key) = user_signing_key {
+			let mut user_signing_key_ids = user_signing_key
+				.deserialize()
+				.map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))?
+				.keys
+				.into_values();
+
+			let user_signing_key_id = user_signing_key_ids
+				.next()
+				.ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
+
+			if user_signing_key_ids.next().is_some() {
+				return Err!(Request(InvalidParam("User signing key contained more than one key.")));
+			}
+
+			let mut user_signing_key_key = prefix;
+			user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
+
+			self.db
+				.keyid_key
+				.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes());
+
+			self.db
+				.userid_usersigningkeyid
+				.insert(user_id.as_bytes(), &user_signing_key_key);
+		}
+
+		if notify {
+			self.mark_device_key_update(user_id).await;
+		}
+
+		Ok(())
 	}

-	pub fn sign_key(
+	pub async fn sign_key(
 		&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
 	) -> Result<()> {
-		self.db.sign_key(target_id, key_id, signature, sender_id)
+		let key = (target_id, key_id);
+
+		let mut cross_signing_key: serde_json::Value = self
+			.db
+			.keyid_key
+			.qry(&key)
+			.await
+			.map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key."))))?
+			.deserialized()
+			.map_err(|e| err!(Database("key in keyid_key is invalid. {e:?}")))?;
+
+		let signatures = cross_signing_key
+			.get_mut("signatures")
+			.ok_or_else(|| err!(Database("key in keyid_key has no signatures field.")))?
+			.as_object_mut()
+			.ok_or_else(|| err!(Database("key in keyid_key has invalid signatures field.")))?
+			.entry(sender_id.to_string())
+			.or_insert_with(|| serde_json::Map::new().into());
+
+		signatures
+			.as_object_mut()
+			.ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))?
+			.insert(signature.0, signature.1.into());
+
+		let key = (target_id, key_id);
+		self.db.keyid_key.put(key, Json(cross_signing_key));
+
+		self.mark_device_key_update(target_id).await;
+
+		Ok(())
 	}

+	#[inline]
 	pub fn keys_changed<'a>(
-		&'a self, user_or_room_id: &str, from: u64, to: Option<u64>,
-	) -> impl Iterator<Item = Result<OwnedUserId>> + 'a {
-		self.db.keys_changed(user_or_room_id, from, to)
+		&'a self, user_id: &'a UserId, from: u64, to: Option<u64>,
+	) -> impl Stream<Item = &UserId> + Send + 'a {
+		self.keys_changed_user_or_room(user_id.as_str(), from, to)
+			.map(|(user_id, ..)| user_id)
 	}

 	#[inline]
-	pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { self.db.mark_device_key_update(user_id) }
+	pub fn room_keys_changed<'a>(
+		&'a self, room_id: &'a RoomId, from: u64, to: Option<u64>,
+	) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
+		self.keys_changed_user_or_room(room_id.as_str(), from, to)
+	}
+
+	fn keys_changed_user_or_room<'a>(
+		&'a self, user_or_room_id: &'a str, from: u64, to: Option<u64>,
+	) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
+		type KeyVal<'a> = ((&'a str, u64), &'a UserId);

-	pub fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> {
-		self.db.get_device_keys(user_id, device_id)
+		let to = to.unwrap_or(u64::MAX);
+		let start = (user_or_room_id, from.saturating_add(1));
+		self.db
+			.keychangeid_userid
+			.stream_from(&start)
+			.ignore_err()
+			.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| *prefix == user_or_room_id && *count <= to)
+			.map(|((_, count), user_id): KeyVal<'_>| (user_id, count))
 	}

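`keys_changed_user_or_room` depends on `keychangeid_userid` keys sorting as `(prefix, count)`, so seeking to `(id, from + 1)` and taking while the prefix still matches and the count stays at or below `to` yields exactly the requested window. A self-contained illustration of that bound logic, with plain iterators standing in for the seek and the db stream:

// Plain-iterator stand-in for stream_from + ready_take_while over sorted
// (prefix, count) keys.
fn main() {
	let entries = [("!r", 2_u64), ("!r", 5), ("!r", 9), ("!s", 1)];
	let (id, from, to) = ("!r", 2_u64, Some(8_u64));

	let hits: Vec<u64> = entries
		.iter()
		.skip_while(|(_, c)| *c <= from) // stands in for the seek to (id, from + 1)
		.take_while(|(p, c)| *p == id && *c <= to.unwrap_or(u64::MAX))
		.map(|&(_, c)| c)
		.collect();

	assert_eq!(hits, [5]); // count 2 is before the window, count 9 is past `to`
}
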
-	#[inline]
-	pub fn parse_master_key(
-		&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
-	) -> Result<(Vec<u8>, CrossSigningKey)> {
-		Data::parse_master_key(user_id, master_key)
+	pub async fn mark_device_key_update(&self, user_id: &UserId) {
+		let count = self.services.globals.next_count().unwrap();
+
+		self.services
+			.state_cache
+			.rooms_joined(user_id)
+			// Don't send key updates to unencrypted rooms
+			.filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id))
+			.ready_for_each(|room_id| {
+				let key = (room_id, count);
+				self.db.keychangeid_userid.put_raw(key, user_id);
+			})
+			.await;
+
+		let key = (user_id, count);
+		self.db.keychangeid_userid.put_raw(key, user_id);
 	}

-	#[inline]
-	pub fn get_key(
-		&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
-	) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db
-			.get_key(key, sender_user, user_id, allowed_signatures)
+	pub async fn get_device_keys<'a>(&'a self, user_id: &'a UserId, device_id: &DeviceId) -> Result<Raw<DeviceKeys>> {
+		let key_id = (user_id, device_id);
+		self.db.keyid_key.qry(&key_id).await.deserialized()
 	}

-	pub fn get_master_key(
-		&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
-	) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db
-			.get_master_key(sender_user, user_id, allowed_signatures)
+	pub async fn get_key<F>(
+		&self, key_id: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
+	) -> Result<Raw<CrossSigningKey>>
+	where
+		F: Fn(&UserId) -> bool + Send + Sync,
+	{
+		let key: serde_json::Value = self.db.keyid_key.get(key_id).await.deserialized()?;
+
+		let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
+		let raw_value = serde_json::value::to_raw_value(&cleaned)?;
+		Ok(Raw::from_json(raw_value))
 	}

-	pub fn get_self_signing_key(
-		&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
-	) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db
-			.get_self_signing_key(sender_user, user_id, allowed_signatures)
+	pub async fn get_master_key<F>(
+		&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
+	) -> Result<Raw<CrossSigningKey>>
+	where
+		F: Fn(&UserId) -> bool + Send + Sync,
+	{
+		let key_id = self.db.userid_masterkeyid.get(user_id).await?;
+
+		self.get_key(&key_id, sender_user, user_id, allowed_signatures)
+			.await
 	}

-	pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db.get_user_signing_key(user_id)
+	pub async fn get_self_signing_key<F>(
+		&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
+	) -> Result<Raw<CrossSigningKey>>
+	where
+		F: Fn(&UserId) -> bool + Send + Sync,
+	{
+		let key_id = self.db.userid_selfsigningkeyid.get(user_id).await?;
+
+		self.get_key(&key_id, sender_user, user_id, allowed_signatures)
+			.await
 	}

-	pub fn add_to_device_event(
-		&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
-		content: serde_json::Value,
-	) -> Result<()> {
+	pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
 		self.db
-			.add_to_device_event(sender, target_user_id, target_device_id, event_type, content)
+			.userid_usersigningkeyid
+			.get(user_id)
+			.and_then(|key_id| self.db.keyid_key.get(&*key_id))
+			.await
+			.deserialized()
 	}

-	pub fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
-		self.db.get_to_device_events(user_id, device_id)
+	pub async fn add_to_device_event(
+		&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
+		content: serde_json::Value,
+	) {
+		let count = self.services.globals.next_count().unwrap();
+
+		let key = (target_user_id, target_device_id, count);
+		self.db.todeviceid_events.put(
+			key,
+			Json(json!({
+				"type": event_type,
+				"sender": sender,
+				"content": content,
+			})),
+		);
 	}

-	pub fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> {
-		self.db.remove_to_device_events(user_id, device_id, until)
+	pub fn get_to_device_events<'a>(
+		&'a self, user_id: &'a UserId, device_id: &'a DeviceId,
+	) -> impl Stream<Item = Raw<AnyToDeviceEvent>> + Send + 'a {
+		let prefix = (user_id, device_id, Interfix);
+		self.db
+			.todeviceid_events
+			.stream_prefix(&prefix)
+			.ignore_err()
+			.map(|(_, val): (Ignore, Raw<AnyToDeviceEvent>)| val)
 	}

-	pub fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
-		self.db.update_device_metadata(user_id, device_id, device)
+	pub async fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) {
+		let mut prefix = user_id.as_bytes().to_vec();
+		prefix.push(0xFF);
+		prefix.extend_from_slice(device_id.as_bytes());
+		prefix.push(0xFF);
+
+		let mut last = prefix.clone();
+		last.extend_from_slice(&until.to_be_bytes());
+
+		self.db
+			.todeviceid_events
+			.rev_raw_keys_from(&last) // this includes last
+			.ignore_err()
+			.ready_take_while(move |key| key.starts_with(&prefix))
+			.map(|key| {
+				let len = key.len();
+				let start = len.saturating_sub(size_of::<u64>());
+				let count = utils::u64_from_u8(&key[start..len]);
+				(key, count)
+			})
+			.ready_take_while(move |(_, count)| *count <= until)
+			.ready_for_each(|(key, _)| self.db.todeviceid_events.remove(&key))
+			.boxed()
+			.await;
+	}
+
+	pub async fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
+		increment(&self.db.userid_devicelistversion, user_id.as_bytes());
+
+		let key = (user_id, device_id);
+		self.db.userdeviceid_metadata.put(key, Json(device));
+
+		Ok(())
 	}

 	/// Get device metadata.
-	pub fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
-		self.db.get_device_metadata(user_id, device_id)
+	pub async fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Device> {
+		self.db
+			.userdeviceid_metadata
+			.qry(&(user_id, device_id))
+			.await
+			.deserialized()
 	}

-	pub fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
-		self.db.get_devicelist_version(user_id)
+	pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
+		self.db
+			.userid_devicelistversion
+			.get(user_id)
+			.await
+			.deserialized()
 	}

-	pub fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> impl Iterator<Item = Result<Device>> + 'a {
-		self.db.all_devices_metadata(user_id)
+	pub fn all_devices_metadata<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = Device> + Send + 'a {
+		let key = (user_id, Interfix);
+		self.db
+			.userdeviceid_metadata
+			.stream_prefix(&key)
+			.ignore_err()
+			.map(|(_, val): (Ignore, Device)| val)
 	}

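`remove_to_device_events` leans on the fact that the last eight bytes of a `todeviceid_events` key are a big-endian u64 counter, so lexicographic key order doubles as numeric order and the counter parses straight off the key suffix. A self-contained sketch of that suffix handling (illustrative, mirroring the `u64_from_u8` extraction above):

use std::mem::size_of;

// Parse the big-endian counter from the tail of a todeviceid_events key.
fn count_from_key(key: &[u8]) -> u64 {
	let start = key.len().saturating_sub(size_of::<u64>());
	let mut buf = [0_u8; size_of::<u64>()];
	buf.copy_from_slice(&key[start..]);
	u64::from_be_bytes(buf)
}

fn main() {
	let mut key = b"@alice:example.org\xffDEVICEID\xff".to_vec();
	key.extend_from_slice(&42_u64.to_be_bytes());
	assert_eq!(count_from_key(&key), 42);
}
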
-	/// Deactivate account
-	pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
-		// Remove all associated devices
-		for device_id in self.all_device_ids(user_id) {
-			self.remove_device(user_id, &device_id?)?;
-		}
-
-		// Set the password to "" to indicate a deactivated account. Hashes will never
-		// result in an empty string, so the user will not be able to log in again.
-		// Systems like changing the password without logging in should check if the
-		// account is deactivated.
-		self.db.set_password(user_id, None)?;
+	/// Creates a new sync filter. Returns the filter id.
+	pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String {
+		let filter_id = utils::random_string(4);

-		// TODO: Unhook 3PID
-		Ok(())
-	}
+		let key = (user_id, &filter_id);
+		self.db.userfilterid_filter.put(key, Json(filter));

-	/// Creates a new sync filter. Returns the filter id.
-	pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> {
-		self.db.create_filter(user_id, filter)
+		filter_id
 	}

-	pub fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
-		self.db.get_filter(user_id, filter_id)
+	pub async fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<FilterDefinition> {
+		let key = (user_id, filter_id);
+		self.db.userfilterid_filter.qry(&key).await.deserialized()
 	}

 	/// Creates an OpenID token, which can be used to prove that a user has
 	/// access to an account (primarily for integrations)
 	pub fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> {
-		self.db.create_openid_token(user_id, token)
+		use std::num::Saturating as Sat;
+
+		let expires_in = self.services.server.config.openid_token_ttl;
+		let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000);
+
+		let mut value = expires_at.0.to_be_bytes().to_vec();
+		value.extend_from_slice(user_id.as_bytes());
+
+		self.db
+			.openidtoken_expiresatuserid
+			.insert(token.as_bytes(), value.as_slice());
+
+		Ok(expires_in)
 	}

 	/// Find out which user an OpenID access token belongs to.
-	pub fn find_from_openid_token(&self, token: &str) -> Result<OwnedUserId> { self.db.find_from_openid_token(token) }
+	pub async fn find_from_openid_token(&self, token: &str) -> Result<OwnedUserId> {
+		let Ok(value) = self.db.openidtoken_expiresatuserid.get(token).await else {
+			return Err!(Request(Unauthorized("OpenID token is unrecognised")));
+		};
+
+		let (expires_at_bytes, user_bytes) = value.split_at(0_u64.to_be_bytes().len());
+		let expires_at = u64::from_be_bytes(
+			expires_at_bytes
+				.try_into()
+				.map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?,
+		);
+
+		if expires_at < utils::millis_since_unix_epoch() {
+			debug_warn!("OpenID token is expired, removing");
+			self.db.openidtoken_expiresatuserid.remove(token.as_bytes());
+
+			return Err!(Request(Unauthorized("OpenID token is expired")));
+		}
+
+		let user_string = utils::string_from_bytes(user_bytes)
+			.map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?;
+
+		UserId::parse(user_string).map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}")))
+	}
+
+	/// Gets a specific user profile key
+	pub async fn profile_key(&self, user_id: &UserId, profile_key: &str) -> Result<serde_json::Value> {
+		let key = (user_id, profile_key);
+		self.db
+			.useridprofilekey_value
+			.qry(&key)
+			.await
+			.deserialized()
+	}
+
+	/// Gets all the user's profile keys and values in an iterator
+	pub fn all_profile_keys<'a>(
+		&'a self, user_id: &'a UserId,
+	) -> impl Stream<Item = (String, serde_json::Value)> + 'a + Send {
+		type KeyVal = ((Ignore, String), serde_json::Value);
+
+		let prefix = (user_id, Interfix);
+		self.db
+			.useridprofilekey_value
+			.stream_prefix(&prefix)
+			.ignore_err()
+			.map(|((_, key), val): KeyVal| (key, val))
+	}
+
+	/// Sets a new profile key value, removes the key if value is None
+	pub fn set_profile_key(&self, user_id: &UserId, profile_key: &str, profile_key_value: Option<serde_json::Value>) {
+		// TODO: insert to the stable MSC4175 key when it's stable
+		let key = (user_id, profile_key);
+
+		if let Some(value) = profile_key_value {
+			self.db.useridprofilekey_value.put(key, value);
+		} else {
+			self.db.useridprofilekey_value.del(key);
+		}
+	}
+
+	/// Get the timezone of a user.
+	pub async fn timezone(&self, user_id: &UserId) -> Result<String> {
+		// TODO: transparently migrate unstable key usage to the stable key once MSC4133
+		// and MSC4175 are stable, likely a remove/insert in this block.
+
+		// first check the unstable prefix then check the stable prefix
+		let unstable_key = (user_id, "us.cloke.msc4175.tz");
+		let stable_key = (user_id, "m.tz");
+		self.db
+			.useridprofilekey_value
+			.qry(&unstable_key)
+			.or_else(|_| self.db.useridprofilekey_value.qry(&stable_key))
+			.await
+			.deserialized()
+	}
+
+	/// Sets a new timezone or removes it if timezone is None.
+	pub fn set_timezone(&self, user_id: &UserId, timezone: Option<String>) {
+		// TODO: insert to the stable MSC4175 key when it's stable
+		let key = (user_id, "us.cloke.msc4175.tz");
+
+		if let Some(timezone) = timezone {
+			self.db.useridprofilekey_value.put_raw(key, &timezone);
+		} else {
+			self.db.useridprofilekey_value.del(key);
+		}
+	}
+}
+
+pub fn parse_master_key(user_id: &UserId, master_key: &Raw<CrossSigningKey>) -> Result<(Vec<u8>, CrossSigningKey)> {
+	let mut prefix = user_id.as_bytes().to_vec();
+	prefix.push(0xFF);
+
+	let master_key = master_key
+		.deserialize()
+		.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
+	let mut master_key_ids = master_key.keys.values();
+	let master_key_id = master_key_ids
+		.next()
+		.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
+	if master_key_ids.next().is_some() {
+		return Err(Error::BadRequest(
+			ErrorKind::InvalidParam,
+			"Master key contained more than one key.",
+		));
+	}
+	let mut master_key_key = prefix.clone();
+	master_key_key.extend_from_slice(master_key_id.as_bytes());
+	Ok((master_key_key, master_key))
 }

 /// Ensure that a user only sees signatures from themselves and the target user
-pub fn clean_signatures<F: Fn(&UserId) -> bool>(
-	cross_signing_key: &mut serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: F,
-) -> Result<(), Error> {
+fn clean_signatures<F>(
+	mut cross_signing_key: serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &F,
+) -> Result<serde_json::Value>
+where
+	F: Fn(&UserId) -> bool + Send + Sync,
+{
 	if let Some(signatures) = cross_signing_key
 		.get_mut("signatures")
 		.and_then(|v| v.as_object_mut())
@@ -563,5 +942,12 @@ pub fn clean_signatures<F: Fn(&UserId) -> bool>(
 		}
 	}

-	Ok(())
+	Ok(cross_signing_key)
+}
+
+//TODO: this is an ABA
+fn increment(db: &Arc<Map>, key: &[u8]) {
+	let old = db.get_blocking(key);
+	let new = utils::increment(old.ok().as_deref());
+	db.insert(key, new);
 }
diff --git a/tests/test_results/complement/test_results.jsonl b/tests/test_results/complement/test_results.jsonl
index ff695bb7401f051dcec9ade822fc855304a32f57..575a22fe838d2b0c61d205121eb6ddf001817213 100644
--- a/tests/test_results/complement/test_results.jsonl
+++ b/tests/test_results/complement/test_results.jsonl
@@ -225,7 +225,6 @@
 {"Action":"pass","Test":"TestToDeviceMessagesOverFederation/good_connectivity"}
 {"Action":"pass","Test":"TestToDeviceMessagesOverFederation/interrupted_connectivity"}
 {"Action":"fail","Test":"TestToDeviceMessagesOverFederation/stopped_server"}
-{"Action":"pass","Test":"TestUnbanViaInvite"}
 {"Action":"fail","Test":"TestUnknownEndpoints"}
 {"Action":"pass","Test":"TestUnknownEndpoints/Client-server_endpoints"}
 {"Action":"fail","Test":"TestUnknownEndpoints/Key_endpoints"}
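On the `//TODO: this is an ABA` note attached to `increment` in the Rust diff above: the helper reads the old value and writes old plus one as two separate steps, so two concurrent bumps of `userid_devicelistversion` can both observe the same old value and one update is silently lost. A minimal sketch of an in-process alternative that makes the bump a single indivisible step (illustrative only; for the persistent map itself the comparable fix would be something like a merge operator or a lock around the read-modify-write):

use std::sync::atomic::{AtomicU64, Ordering};

static DEVICELIST_VERSION: AtomicU64 = AtomicU64::new(0);

// fetch_add performs the read and the write as one atomic step, so no
// interleaving of two callers can drop a bump.
fn increment_atomic() -> u64 {
	DEVICELIST_VERSION
		.fetch_add(1, Ordering::Relaxed)
		.saturating_add(1)
}

fn main() {
	let first = increment_atomic();
	let second = increment_atomic();
	assert_eq!((first, second), (1, 2));
}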