Skip to content
Snippets Groups Projects
Unverified Commit 82e7f57b authored by Timo Kösters's avatar Timo Kösters Committed by Nyaaori
Browse files

refactor state accessor, state cache, user, uiaa

parent 3e22bbee
No related branches found
No related tags found
No related merge requests found
/// Builds a StateMap by iterating over all keys that start impl service::room::state_accessor::Data for KeyValueDatabase {
/// with state_hash, this gives the full state for the given state_hash. async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
#[tracing::instrument(skip(self))]
pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
let full_state = self let full_state = self
.load_shortstatehash_info(shortstatehash)? .load_shortstatehash_info(shortstatehash)?
.pop() .pop()
...@@ -21,8 +19,7 @@ pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, ...@@ -21,8 +19,7 @@ pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64,
Ok(result) Ok(result)
} }
#[tracing::instrument(skip(self))] async fn state_full(
pub async fn state_full(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
...@@ -59,8 +56,7 @@ pub async fn state_full( ...@@ -59,8 +56,7 @@ pub async fn state_full(
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))] fn state_get_id(
pub fn state_get_id(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
event_type: &StateEventType, event_type: &StateEventType,
...@@ -86,8 +82,7 @@ pub fn state_get_id( ...@@ -86,8 +82,7 @@ pub fn state_get_id(
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))] fn state_get(
pub fn state_get(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
event_type: &StateEventType, event_type: &StateEventType,
...@@ -98,7 +93,7 @@ pub fn state_get( ...@@ -98,7 +93,7 @@ pub fn state_get(
} }
/// Returns the state hash for this pdu. /// Returns the state hash for this pdu.
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid self.eventid_shorteventid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| { .map_or(Ok(None), |shorteventid| {
...@@ -116,8 +111,7 @@ pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { ...@@ -116,8 +111,7 @@ pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
} }
/// Returns the full room state. /// Returns the full room state.
#[tracing::instrument(skip(self))] async fn room_state_full(
pub async fn room_state_full(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
...@@ -129,8 +123,7 @@ pub async fn room_state_full( ...@@ -129,8 +123,7 @@ pub async fn room_state_full(
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))] fn room_state_get_id(
pub fn room_state_get_id(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event_type: &StateEventType, event_type: &StateEventType,
...@@ -144,8 +137,7 @@ pub fn room_state_get_id( ...@@ -144,8 +137,7 @@ pub fn room_state_get_id(
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))] fn room_state_get(
pub fn room_state_get(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event_type: &StateEventType, event_type: &StateEventType,
...@@ -157,4 +149,3 @@ pub fn room_state_get( ...@@ -157,4 +149,3 @@ pub fn room_state_get(
Ok(None) Ok(None)
} }
} }
impl service::room::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[])?;
}
}
impl service::room::user::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))] fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
...@@ -13,8 +12,7 @@ pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> R ...@@ -13,8 +12,7 @@ pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> R
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self))] fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
...@@ -28,8 +26,7 @@ pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u ...@@ -28,8 +26,7 @@ pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u
.unwrap_or(Ok(0)) .unwrap_or(Ok(0))
} }
#[tracing::instrument(skip(self))] fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
...@@ -43,7 +40,7 @@ pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> ...@@ -43,7 +40,7 @@ pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>
.unwrap_or(Ok(0)) .unwrap_or(Ok(0))
} }
pub fn associate_token_shortstatehash( fn associate_token_shortstatehash(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
token: u64, token: u64,
...@@ -58,7 +55,7 @@ pub fn associate_token_shortstatehash( ...@@ -58,7 +55,7 @@ pub fn associate_token_shortstatehash(
.insert(&key, &shortstatehash.to_be_bytes()) .insert(&key, &shortstatehash.to_be_bytes())
} }
pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); let shortroomid = self.get_shortroomid(room_id)?.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
...@@ -74,8 +71,7 @@ pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<O ...@@ -74,8 +71,7 @@ pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<O
.transpose() .transpose()
} }
#[tracing::instrument(skip(self))] fn get_shared_rooms<'a>(
pub fn get_shared_rooms<'a>(
&'a self, &'a self,
users: Vec<Box<UserId>>, users: Vec<Box<UserId>>,
) -> Result<impl Iterator<Item = Result<Box<RoomId>>> + 'a> { ) -> Result<impl Iterator<Item = Result<Box<RoomId>>> + 'a> {
...@@ -111,4 +107,4 @@ pub fn get_shared_rooms<'a>( ...@@ -111,4 +107,4 @@ pub fn get_shared_rooms<'a>(
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
})) }))
} }
}
use std::{ impl service::uiaa::Data for KeyValueDatabase {
collections::BTreeMap,
sync::{Arc, RwLock},
};
use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result};
use ruma::{
api::client::{
error::ErrorKind,
uiaa::{
AuthType, IncomingAuthData, IncomingPassword,
IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo,
},
},
signatures::CanonicalJsonValue,
DeviceId, UserId,
};
use tracing::error;
use super::abstraction::Tree;
pub struct Uiaa {
pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication
pub(super) userdevicesessionid_uiaarequest:
RwLock<BTreeMap<(Box<UserId>, Box<DeviceId>, String), CanonicalJsonValue>>,
}
impl Uiaa {
/// Creates a new Uiaa session. Make sure the session token is unique.
pub fn create(
&self,
user_id: &UserId,
device_id: &DeviceId,
uiaainfo: &UiaaInfo,
json_body: &CanonicalJsonValue,
) -> Result<()> {
self.set_uiaa_request(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?)
json_body,
)?;
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session should be set"),
Some(uiaainfo),
)
}
pub fn try_auth(
&self,
user_id: &UserId,
device_id: &DeviceId,
auth: &IncomingAuthData,
uiaainfo: &UiaaInfo,
users: &super::users::Users,
globals: &super::globals::Globals,
) -> Result<(bool, UiaaInfo)> {
let mut uiaainfo = auth
.session()
.map(|session| self.get_uiaa_session(user_id, device_id, session))
.unwrap_or_else(|| Ok(uiaainfo.clone()))?;
if uiaainfo.session.is_none() {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
}
match auth {
// Find out what the user completed
IncomingAuthData::Password(IncomingPassword {
identifier,
password,
..
}) => {
let username = match identifier {
UserIdOrLocalpart(username) => username,
_ => {
return Err(Error::BadRequest(
ErrorKind::Unrecognized,
"Identifier type not recognized.",
))
}
};
let user_id =
UserId::parse_with_server_name(username.clone(), globals.server_name())
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
})?;
// Check if password is correct
if let Some(hash) = users.password_hash(&user_id)? {
let hash_matches =
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
if !hash_matches {
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
kind: ErrorKind::Forbidden,
message: "Invalid username or password.".to_owned(),
});
return Ok((false, uiaainfo));
}
}
// Password was correct! Let's add it to `completed`
uiaainfo.completed.push(AuthType::Password);
}
IncomingAuthData::Dummy(_) => {
uiaainfo.completed.push(AuthType::Dummy);
}
k => error!("type not supported: {:?}", k),
}
// Check if a flow now succeeds
let mut completed = false;
'flows: for flow in &mut uiaainfo.flows {
for stage in &flow.stages {
if !uiaainfo.completed.contains(stage) {
continue 'flows;
}
}
// We didn't break, so this flow succeeded!
completed = true;
}
if !completed {
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
Some(&uiaainfo),
)?;
return Ok((false, uiaainfo));
}
// UIAA was successful! Remove this session and return true
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
None,
)?;
Ok((true, uiaainfo))
}
fn set_uiaa_request( fn set_uiaa_request(
&self, &self,
user_id: &UserId, user_id: &UserId,
...@@ -162,7 +17,7 @@ fn set_uiaa_request( ...@@ -162,7 +17,7 @@ fn set_uiaa_request(
Ok(()) Ok(())
} }
pub fn get_uiaa_request( fn get_uiaa_request(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
......
pub trait Data {
/// Builds a StateMap by iterating over all keys that start /// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash. /// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self))] async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>>;
pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = BTreeMap::new();
let mut i = 0;
for compressed in full_state.into_iter() {
let parsed = self.parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1);
i += 1; async fn state_full(
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
}
#[tracing::instrument(skip(self))]
pub async fn state_full(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>>;
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i = 0;
for compressed in full_state {
let (_, eventid) = self.parse_compressed_state_event(compressed)?;
if let Some(pdu) = self.get_pdu(&eventid)? {
result.insert(
(
pdu.kind.to_string().into(),
pdu.state_key
.as_ref()
.ok_or_else(|| Error::bad_database("State event has no state key."))?
.clone(),
),
pdu,
);
}
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))] fn state_get_id(
pub fn state_get_id(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
) -> Result<Option<Arc<EventId>>> { ) -> Result<Option<Arc<EventId>>>;
let shortstatekey = match self.get_shortstatekey(event_type, state_key)? {
Some(s) => s,
None => return Ok(None),
};
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.into_iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
self.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))] fn state_get(
pub fn state_get(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> { ) -> Result<Option<Arc<PduEvent>>>;
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| self.get_pdu(&event_id))
}
/// Returns the state hash for this pdu. /// Returns the state hash for this pdu.
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>>;
self.eventid_shorteventid
.get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| {
self.shorteventid_shortstatehash
.get(&shorteventid)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"Invalid shortstatehash bytes in shorteventid_shortstatehash",
)
})
})
.transpose()
})
}
/// Returns the full room state. /// Returns the full room state.
#[tracing::instrument(skip(self))] async fn room_state_full(
pub async fn room_state_full(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>>;
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))] fn room_state_get_id(
pub fn room_state_get_id(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
) -> Result<Option<Arc<EventId>>> { ) -> Result<Option<Arc<EventId>>>;
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))] fn room_state_get(
pub fn room_state_get(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> { ) -> Result<Option<Arc<PduEvent>>>;
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { }
self.state_get(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
mod data;
pub use data::Data;
use crate::service::*;
pub struct Service<D: Data> {
db: D,
}
impl Service<_> {
/// Builds a StateMap by iterating over all keys that start /// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash. /// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> { pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
let full_state = self self.db.state_full_ids(shortstatehash)
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = BTreeMap::new();
let mut i = 0;
for compressed in full_state.into_iter() {
let parsed = self.parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1);
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
...@@ -26,36 +20,7 @@ pub async fn state_full( ...@@ -26,36 +20,7 @@ pub async fn state_full(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
let full_state = self self.db.state_full(shortstatehash)
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i = 0;
for compressed in full_state {
let (_, eventid) = self.parse_compressed_state_event(compressed)?;
if let Some(pdu) = self.get_pdu(&eventid)? {
result.insert(
(
pdu.kind.to_string().into(),
pdu.state_key
.as_ref()
.ok_or_else(|| Error::bad_database("State event has no state key."))?
.clone(),
),
pdu,
);
}
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
Ok(result)
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
...@@ -66,23 +31,7 @@ pub fn state_get_id( ...@@ -66,23 +31,7 @@ pub fn state_get_id(
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
) -> Result<Option<Arc<EventId>>> { ) -> Result<Option<Arc<EventId>>> {
let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { self.db.state_get_id(shortstatehash, event_type, state_key)
Some(s) => s,
None => return Ok(None),
};
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.into_iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
self.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
...@@ -93,26 +42,12 @@ pub fn state_get( ...@@ -93,26 +42,12 @@ pub fn state_get(
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> { ) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)? self.db.pdu_state_get(event_id)
.map_or(Ok(None), |event_id| self.get_pdu(&event_id))
} }
/// Returns the state hash for this pdu. /// Returns the state hash for this pdu.
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid self.db.pdu_shortstatehash(event_id)
.get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| {
self.shorteventid_shortstatehash
.get(&shorteventid)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"Invalid shortstatehash bytes in shorteventid_shortstatehash",
)
})
})
.transpose()
})
} }
/// Returns the full room state. /// Returns the full room state.
...@@ -121,11 +56,7 @@ pub async fn room_state_full( ...@@ -121,11 +56,7 @@ pub async fn room_state_full(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { self.db.room_state_full(room_id)
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
}
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
...@@ -136,11 +67,7 @@ pub fn room_state_get_id( ...@@ -136,11 +67,7 @@ pub fn room_state_get_id(
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
) -> Result<Option<Arc<EventId>>> { ) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { self.db.room_state_get_id(room_id, event_type, state_key)
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
...@@ -151,10 +78,6 @@ pub fn room_state_get( ...@@ -151,10 +78,6 @@ pub fn room_state_get(
event_type: &StateEventType, event_type: &StateEventType,
state_key: &str, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> { ) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { self.db.room_state_get(room_id, event_type, state_key)
self.state_get(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
} }
}
pub trait Data {
fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()>;
}
mod data;
pub use data::Data;
use crate::service::*;
pub struct Service<D: Data> {
db: D,
}
impl Service<_> {
/// Update current membership data. /// Update current membership data.
#[tracing::instrument(skip(self, last_state, db))] #[tracing::instrument(skip(self, last_state, db))]
pub fn update_membership( pub fn update_membership(
...@@ -25,10 +34,6 @@ pub fn update_membership( ...@@ -25,10 +34,6 @@ pub fn update_membership(
serverroom_id.push(0xff); serverroom_id.push(0xff);
serverroom_id.extend_from_slice(room_id.as_bytes()); serverroom_id.extend_from_slice(room_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
let mut roomuser_id = room_id.as_bytes().to_vec(); let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff); roomuser_id.push(0xff);
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
...@@ -38,7 +43,7 @@ pub fn update_membership( ...@@ -38,7 +43,7 @@ pub fn update_membership(
// Check if the user never joined this room // Check if the user never joined this room
if !self.once_joined(user_id, room_id)? { if !self.once_joined(user_id, room_id)? {
// Add the user ID to the join list then // Add the user ID to the join list then
self.roomuseroncejoinedids.insert(&userroom_id, &[])?; self.db.mark_as_once_joined(user_id, room_id)?;
// Check if the room has a predecessor // Check if the room has a predecessor
if let Some(predecessor) = self if let Some(predecessor) = self
...@@ -116,10 +121,6 @@ pub fn update_membership( ...@@ -116,10 +121,6 @@ pub fn update_membership(
} }
} }
if update_joined_count {
self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?;
}
self.userroomid_joined.insert(&userroom_id, &[])?; self.userroomid_joined.insert(&userroom_id, &[])?;
self.roomuserid_joined.insert(&roomuser_id, &[])?; self.roomuserid_joined.insert(&roomuser_id, &[])?;
self.userroomid_invitestate.remove(&userroom_id)?; self.userroomid_invitestate.remove(&userroom_id)?;
...@@ -150,10 +151,6 @@ pub fn update_membership( ...@@ -150,10 +151,6 @@ pub fn update_membership(
return Ok(()); return Ok(());
} }
if update_joined_count {
self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?;
}
self.userroomid_invitestate.insert( self.userroomid_invitestate.insert(
&userroom_id, &userroom_id,
&serde_json::to_vec(&last_state.unwrap_or_default()) &serde_json::to_vec(&last_state.unwrap_or_default())
...@@ -167,16 +164,6 @@ pub fn update_membership( ...@@ -167,16 +164,6 @@ pub fn update_membership(
self.roomuserid_leftcount.remove(&roomuser_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?;
} }
MembershipState::Leave | MembershipState::Ban => { MembershipState::Leave | MembershipState::Ban => {
if update_joined_count
&& self
.room_members(room_id)
.chain(self.room_members_invited(room_id))
.filter_map(|r| r.ok())
.all(|u| u.server_name() != user_id.server_name())
{
self.roomserverids.remove(&roomserver_id)?;
self.serverroomids.remove(&serverroom_id)?;
}
self.userroomid_leftstate.insert( self.userroomid_leftstate.insert(
&userroom_id, &userroom_id,
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(), &serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
...@@ -231,36 +218,6 @@ pub fn update_joined_count(&self, room_id: &RoomId, db: &Database) -> Result<()> ...@@ -231,36 +218,6 @@ pub fn update_joined_count(&self, room_id: &RoomId, db: &Database) -> Result<()>
.unwrap() .unwrap()
.insert(room_id.to_owned(), Arc::new(real_users)); .insert(room_id.to_owned(), Arc::new(real_users));
for old_joined_server in self.room_servers(room_id).filter_map(|r| r.ok()) {
if !joined_servers.remove(&old_joined_server) {
// Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xff);
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xff);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.remove(&roomserver_id)?;
self.serverroomids.remove(&serverroom_id)?;
}
}
// Now only new servers are in joined_servers anymore
for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xff);
roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xff);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?;
}
self.appservice_in_room_cache self.appservice_in_room_cache
.write() .write()
.unwrap() .unwrap()
...@@ -714,4 +671,4 @@ pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { ...@@ -714,4 +671,4 @@ pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
} }
}
pub trait Data {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
#[tracing::instrument(skip(self))] fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>;
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>;
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.userroomid_highlightcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
Ok(()) fn associate_token_shortstatehash(
}
#[tracing::instrument(skip(self))]
pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount
.get(&userroom_id)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid notification count in db."))
})
.unwrap_or(Ok(0))
}
#[tracing::instrument(skip(self))]
pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_highlightcount
.get(&userroom_id)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid highlight count in db."))
})
.unwrap_or(Ok(0))
}
pub fn associate_token_shortstatehash(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
token: u64, token: u64,
shortstatehash: u64, shortstatehash: u64,
) -> Result<()> { ) -> Result<()>;
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.insert(&key, &shortstatehash.to_be_bytes())
}
pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>>;
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec(); fn get_shared_rooms<'a>(
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")
})
})
.transpose()
}
#[tracing::instrument(skip(self))]
pub fn get_shared_rooms<'a>(
&'a self, &'a self,
users: Vec<Box<UserId>>, users: Vec<Box<UserId>>,
) -> Result<impl Iterator<Item = Result<Box<RoomId>>> + 'a> { ) -> Result<impl Iterator<Item = Result<Box<RoomId>>> + 'a>;
let iterators = users.into_iter().map(move |user_id| { }
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
self.userroomid_joined
.scan_prefix(prefix)
.map(|(key, _)| {
let roomid_index = key
.iter()
.enumerate()
.find(|(_, &b)| b == 0xff)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
.0
+ 1; // +1 because the room id starts AFTER the separator
let room_id = key[roomid_index..].to_vec();
Ok::<_, Error>(room_id)
})
.filter_map(|r| r.ok())
});
// We use the default compare function because keys are sorted correctly (not reversed)
Ok(utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| {
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
})?)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
}))
}
mod data;
pub use data::Data;
#[tracing::instrument(skip(self))] use crate::service::*;
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount pub struct Service<D: Data> {
.insert(&userroom_id, &0_u64.to_be_bytes())?; db: D,
self.userroomid_highlightcount }
.insert(&userroom_id, &0_u64.to_be_bytes())?;
Ok(()) impl Service<_> {
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
self.db.reset_notification_counts(user_id, room_id)
} }
#[tracing::instrument(skip(self))]
pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec(); self.db.notification_count(user_id, room_id)
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount
.get(&userroom_id)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid notification count in db."))
})
.unwrap_or(Ok(0))
} }
#[tracing::instrument(skip(self))]
pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec(); self.db.highlight_count(user_id, room_id)
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_highlightcount
.get(&userroom_id)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid highlight count in db."))
})
.unwrap_or(Ok(0))
} }
pub fn associate_token_shortstatehash( pub fn associate_token_shortstatehash(
...@@ -49,66 +26,17 @@ pub fn associate_token_shortstatehash( ...@@ -49,66 +26,17 @@ pub fn associate_token_shortstatehash(
token: u64, token: u64,
shortstatehash: u64, shortstatehash: u64,
) -> Result<()> { ) -> Result<()> {
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); self.db.associate_token_shortstatehash(user_id, room_id)
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.insert(&key, &shortstatehash.to_be_bytes())
} }
pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); self.db.get_token_shortstatehash(room_id, token)
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")
})
})
.transpose()
} }
#[tracing::instrument(skip(self))]
pub fn get_shared_rooms<'a>( pub fn get_shared_rooms<'a>(
&'a self, &'a self,
users: Vec<Box<UserId>>, users: Vec<Box<UserId>>,
) -> Result<impl Iterator<Item = Result<Box<RoomId>>> + 'a> { ) -> Result<impl Iterator<Item = Result<Box<RoomId>>> + 'a> {
let iterators = users.into_iter().map(move |user_id| { self.db.get_shared_rooms(users)
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
self.userroomid_joined
.scan_prefix(prefix)
.map(|(key, _)| {
let roomid_index = key
.iter()
.enumerate()
.find(|(_, &b)| b == 0xff)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
.0
+ 1; // +1 because the room id starts AFTER the separator
let room_id = key[roomid_index..].to_vec();
Ok::<_, Error>(room_id)
})
.filter_map(|r| r.ok())
});
// We use the default compare function because keys are sorted correctly (not reversed)
Ok(utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| {
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
})?)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
}))
} }
}
use std::{ pub trait Data {
collections::BTreeMap,
sync::{Arc, RwLock},
};
use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result};
use ruma::{
api::client::{
error::ErrorKind,
uiaa::{
AuthType, IncomingAuthData, IncomingPassword,
IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo,
},
},
signatures::CanonicalJsonValue,
DeviceId, UserId,
};
use tracing::error;
use super::abstraction::Tree;
pub struct Uiaa {
pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication
pub(super) userdevicesessionid_uiaarequest:
RwLock<BTreeMap<(Box<UserId>, Box<DeviceId>, String), CanonicalJsonValue>>,
}
impl Uiaa {
/// Creates a new Uiaa session. Make sure the session token is unique.
pub fn create(
&self,
user_id: &UserId,
device_id: &DeviceId,
uiaainfo: &UiaaInfo,
json_body: &CanonicalJsonValue,
) -> Result<()> {
self.set_uiaa_request(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?)
json_body,
)?;
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session should be set"),
Some(uiaainfo),
)
}
pub fn try_auth(
&self,
user_id: &UserId,
device_id: &DeviceId,
auth: &IncomingAuthData,
uiaainfo: &UiaaInfo,
users: &super::users::Users,
globals: &super::globals::Globals,
) -> Result<(bool, UiaaInfo)> {
let mut uiaainfo = auth
.session()
.map(|session| self.get_uiaa_session(user_id, device_id, session))
.unwrap_or_else(|| Ok(uiaainfo.clone()))?;
if uiaainfo.session.is_none() {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
}
match auth {
// Find out what the user completed
IncomingAuthData::Password(IncomingPassword {
identifier,
password,
..
}) => {
let username = match identifier {
UserIdOrLocalpart(username) => username,
_ => {
return Err(Error::BadRequest(
ErrorKind::Unrecognized,
"Identifier type not recognized.",
))
}
};
let user_id =
UserId::parse_with_server_name(username.clone(), globals.server_name())
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
})?;
// Check if password is correct
if let Some(hash) = users.password_hash(&user_id)? {
let hash_matches =
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
if !hash_matches {
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
kind: ErrorKind::Forbidden,
message: "Invalid username or password.".to_owned(),
});
return Ok((false, uiaainfo));
}
}
// Password was correct! Let's add it to `completed`
uiaainfo.completed.push(AuthType::Password);
}
IncomingAuthData::Dummy(_) => {
uiaainfo.completed.push(AuthType::Dummy);
}
k => error!("type not supported: {:?}", k),
}
// Check if a flow now succeeds
let mut completed = false;
'flows: for flow in &mut uiaainfo.flows {
for stage in &flow.stages {
if !uiaainfo.completed.contains(stage) {
continue 'flows;
}
}
// We didn't break, so this flow succeeded!
completed = true;
}
if !completed {
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
Some(&uiaainfo),
)?;
return Ok((false, uiaainfo));
}
// UIAA was successful! Remove this session and return true
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
None,
)?;
Ok((true, uiaainfo))
}
fn set_uiaa_request( fn set_uiaa_request(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
session: &str, session: &str,
request: &CanonicalJsonValue, request: &CanonicalJsonValue,
) -> Result<()> { ) -> Result<()>;
self.userdevicesessionid_uiaarequest
.write()
.unwrap()
.insert(
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(),
);
Ok(())
}
pub fn get_uiaa_request( fn get_uiaa_request(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
session: &str, session: &str,
) -> Option<CanonicalJsonValue> { ) -> Option<CanonicalJsonValue>;
self.userdevicesessionid_uiaarequest
.read()
.unwrap()
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
.map(|j| j.to_owned())
}
fn update_uiaa_session( fn update_uiaa_session(
&self, &self,
...@@ -181,47 +20,12 @@ fn update_uiaa_session( ...@@ -181,47 +20,12 @@ fn update_uiaa_session(
device_id: &DeviceId, device_id: &DeviceId,
session: &str, session: &str,
uiaainfo: Option<&UiaaInfo>, uiaainfo: Option<&UiaaInfo>,
) -> Result<()> { ) -> Result<()>;
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(session.as_bytes());
if let Some(uiaainfo) = uiaainfo {
self.userdevicesessionid_uiaainfo.insert(
&userdevicesessionid,
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
)?;
} else {
self.userdevicesessionid_uiaainfo
.remove(&userdevicesessionid)?;
}
Ok(())
}
fn get_uiaa_session( fn get_uiaa_session(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
session: &str, session: &str,
) -> Result<UiaaInfo> { ) -> Result<UiaaInfo>;
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(session.as_bytes());
serde_json::from_slice(
&self
.userdevicesessionid_uiaainfo
.get(&userdevicesessionid)?
.ok_or(Error::BadRequest(
ErrorKind::Forbidden,
"UIAA session does not exist.",
))?,
)
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
}
} }
use std::{ mod data;
collections::BTreeMap, pub use data::Data;
sync::{Arc, RwLock},
};
use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; use crate::service::*;
use ruma::{
api::client::{
error::ErrorKind,
uiaa::{
AuthType, IncomingAuthData, IncomingPassword,
IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo,
},
},
signatures::CanonicalJsonValue,
DeviceId, UserId,
};
use tracing::error;
use super::abstraction::Tree; pub struct Service<D: Data> {
db: D,
pub struct Uiaa {
pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication
pub(super) userdevicesessionid_uiaarequest:
RwLock<BTreeMap<(Box<UserId>, Box<DeviceId>, String), CanonicalJsonValue>>,
} }
impl Uiaa { impl Service<_> {
/// Creates a new Uiaa session. Make sure the session token is unique. /// Creates a new Uiaa session. Make sure the session token is unique.
pub fn create( pub fn create(
&self, &self,
...@@ -144,35 +126,13 @@ pub fn try_auth( ...@@ -144,35 +126,13 @@ pub fn try_auth(
Ok((true, uiaainfo)) Ok((true, uiaainfo))
} }
fn set_uiaa_request(
&self,
user_id: &UserId,
device_id: &DeviceId,
session: &str,
request: &CanonicalJsonValue,
) -> Result<()> {
self.userdevicesessionid_uiaarequest
.write()
.unwrap()
.insert(
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(),
);
Ok(())
}
pub fn get_uiaa_request( pub fn get_uiaa_request(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
session: &str, session: &str,
) -> Option<CanonicalJsonValue> { ) -> Option<CanonicalJsonValue> {
self.userdevicesessionid_uiaarequest self.db.get_uiaa_request(user_id, device_id, session)
.read()
.unwrap()
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
.map(|j| j.to_owned())
} }
fn update_uiaa_session( fn update_uiaa_session(
...@@ -182,46 +142,6 @@ fn update_uiaa_session( ...@@ -182,46 +142,6 @@ fn update_uiaa_session(
session: &str, session: &str,
uiaainfo: Option<&UiaaInfo>, uiaainfo: Option<&UiaaInfo>,
) -> Result<()> { ) -> Result<()> {
let mut userdevicesessionid = user_id.as_bytes().to_vec(); self.db.update_uiaa_session(user_id, device_id, session, uiaainfo)
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(session.as_bytes());
if let Some(uiaainfo) = uiaainfo {
self.userdevicesessionid_uiaainfo.insert(
&userdevicesessionid,
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
)?;
} else {
self.userdevicesessionid_uiaainfo
.remove(&userdevicesessionid)?;
}
Ok(())
}
fn get_uiaa_session(
&self,
user_id: &UserId,
device_id: &DeviceId,
session: &str,
) -> Result<UiaaInfo> {
let mut userdevicesessionid = user_id.as_bytes().to_vec();
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(device_id.as_bytes());
userdevicesessionid.push(0xff);
userdevicesessionid.extend_from_slice(session.as_bytes());
serde_json::from_slice(
&self
.userdevicesessionid_uiaainfo
.get(&userdevicesessionid)?
.ok_or(Error::BadRequest(
ErrorKind::Forbidden,
"UIAA session does not exist.",
))?,
)
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment