From 04115dd8b149af44994f02efd13d097cb3dccd60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Mon, 4 Nov 2024 15:30:08 +0100 Subject: [PATCH] fix: Use the DisplayName struct to protect against homoglyph attacks --- crates/matrix-sdk-base/src/client.rs | 12 +- .../src/deserialized_responses.rs | 10 +- crates/matrix-sdk-base/src/rooms/members.rs | 13 +- crates/matrix-sdk-base/src/rooms/normal.rs | 9 +- .../matrix-sdk-base/src/sliding_sync/mod.rs | 12 +- .../src/store/ambiguity_map.rs | 245 +++++++++++++++++- .../src/store/integration_tests.rs | 32 ++- .../matrix-sdk-base/src/store/memory_store.rs | 20 +- crates/matrix-sdk-base/src/store/mod.rs | 5 +- crates/matrix-sdk-base/src/store/traits.rs | 18 +- .../src/state_store/mod.rs | 47 +++- crates/matrix-sdk-sqlite/src/state_store.rs | 90 +++++-- 12 files changed, 408 insertions(+), 105 deletions(-) diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index a65d3e34ee7..1fcaee38b26 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -16,7 +16,7 @@ #[cfg(feature = "e2e-encryption")] use std::sync::Arc; use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashMap}, fmt, iter, ops::Deref, }; @@ -68,7 +68,7 @@ use crate::latest_event::{is_suitable_for_latest_event, LatestEvent, PossibleLat #[cfg(feature = "e2e-encryption")] use crate::RoomMemberships; use crate::{ - deserialized_responses::{RawAnySyncOrStrippedTimelineEvent, SyncTimelineEvent}, + deserialized_responses::{DisplayName, RawAnySyncOrStrippedTimelineEvent, SyncTimelineEvent}, error::{Error, Result}, event_cache::store::EventCacheStoreLock, response_processors::AccountDataProcessor, @@ -1332,7 +1332,7 @@ impl BaseClient { #[cfg(feature = "e2e-encryption")] let mut user_ids = BTreeSet::new(); - let mut ambiguity_map: BTreeMap> = BTreeMap::new(); + let mut ambiguity_map: HashMap> = Default::default(); for raw_event in &response.chunk { let member = match raw_event.deserialize() { @@ -1363,7 +1363,11 @@ impl BaseClient { if let StateEvent::Original(e) = &member { if let Some(d) = &e.content.displayname { - ambiguity_map.entry(d.clone()).or_default().insert(member.state_key().clone()); + let display_name = DisplayName::new(d); + ambiguity_map + .entry(display_name) + .or_default() + .insert(member.state_key().clone()); } } diff --git a/crates/matrix-sdk-base/src/deserialized_responses.rs b/crates/matrix-sdk-base/src/deserialized_responses.rs index 7f34d58dfae..228ba3b3e92 100644 --- a/crates/matrix-sdk-base/src/deserialized_responses.rs +++ b/crates/matrix-sdk-base/src/deserialized_responses.rs @@ -454,10 +454,12 @@ impl MemberEvent { /// /// It there is no `displayname` in the event's content, the localpart or /// the user ID is returned. - pub fn display_name(&self) -> &str { - self.original_content() - .and_then(|c| c.displayname.as_deref()) - .unwrap_or_else(|| self.user_id().localpart()) + pub fn display_name(&self) -> DisplayName { + DisplayName::new( + self.original_content() + .and_then(|c| c.displayname.as_deref()) + .unwrap_or_else(|| self.user_id().localpart()), + ) } } diff --git a/crates/matrix-sdk-base/src/rooms/members.rs b/crates/matrix-sdk-base/src/rooms/members.rs index ad5764b7c5c..a013653949b 100644 --- a/crates/matrix-sdk-base/src/rooms/members.rs +++ b/crates/matrix-sdk-base/src/rooms/members.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeSet, HashMap}, sync::Arc, }; @@ -30,7 +30,8 @@ use ruma::{ }; use crate::{ - deserialized_responses::{MemberEvent, SyncOrStrippedState}, + deserialized_responses::{DisplayName, MemberEvent, SyncOrStrippedState}, + store::ambiguity_map::is_display_name_ambiguous, MinimalRoomMemberEvent, }; @@ -67,8 +68,10 @@ impl RoomMember { } = room_info; let is_room_creator = room_creator.as_deref() == Some(event.user_id()); - let display_name_ambiguous = - users_display_names.get(event.display_name()).is_some_and(|s| s.len() > 1); + let display_name = event.display_name(); + let display_name_ambiguous = users_display_names + .get(&display_name) + .is_some_and(|s| is_display_name_ambiguous(&display_name, s)); let is_ignored = ignored_users.as_ref().is_some_and(|s| s.contains(event.user_id())); Self { @@ -245,6 +248,6 @@ pub(crate) struct MemberRoomInfo<'a> { pub(crate) power_levels: Arc>>, pub(crate) max_power_level: i64, pub(crate) room_creator: Option, - pub(crate) users_display_names: BTreeMap<&'a str, BTreeSet>, + pub(crate) users_display_names: HashMap<&'a DisplayName, BTreeSet>, pub(crate) ignored_users: Option>, } diff --git a/crates/matrix-sdk-base/src/rooms/normal.rs b/crates/matrix-sdk-base/src/rooms/normal.rs index 81bca1b6454..64e4b59f2fa 100644 --- a/crates/matrix-sdk-base/src/rooms/normal.rs +++ b/crates/matrix-sdk-base/src/rooms/normal.rs @@ -67,7 +67,7 @@ use super::{ #[cfg(feature = "experimental-sliding-sync")] use crate::latest_event::LatestEvent; use crate::{ - deserialized_responses::{MemberEvent, RawSyncOrStrippedState}, + deserialized_responses::{DisplayName, MemberEvent, RawSyncOrStrippedState}, notification_settings::RoomNotificationMode, read_receipts::RoomReadReceipts, store::{DynStateStore, Result as StoreResult, StateStoreExt}, @@ -819,8 +819,7 @@ impl Room { }) .collect::>(); - let display_names = - member_events.iter().map(|e| e.display_name().to_owned()).collect::>(); + let display_names = member_events.iter().map(|e| e.display_name()).collect::>(); let room_info = self.member_room_info(&display_names).await?; let mut members = Vec::new(); @@ -900,7 +899,7 @@ impl Room { let profile = self.store.get_profile(self.room_id(), user_id).await?; - let display_names = [event.display_name().to_owned()]; + let display_names = [event.display_name()]; let room_info = self.member_room_info(&display_names).await?; Ok(Some(RoomMember::from_parts(event, profile, presence, &room_info))) @@ -911,7 +910,7 @@ impl Room { /// Async because it can read from storage. async fn member_room_info<'a>( &self, - display_names: &'a [String], + display_names: &'a [DisplayName], ) -> StoreResult> { let max_power_level = self.max_power_level(); let room_creator = self.inner.read().creator().map(ToOwned::to_owned); diff --git a/crates/matrix-sdk-base/src/sliding_sync/mod.rs b/crates/matrix-sdk-base/src/sliding_sync/mod.rs index 9350492a098..19957e1aa08 100644 --- a/crates/matrix-sdk-base/src/sliding_sync/mod.rs +++ b/crates/matrix-sdk-base/src/sliding_sync/mod.rs @@ -695,6 +695,10 @@ async fn cache_latest_events( changes: Option<&StateChanges>, store: Option<&Store>, ) { + use crate::{ + deserialized_responses::DisplayName, store::ambiguity_map::is_display_name_ambiguous, + }; + let mut encrypted_events = Vec::with_capacity(room.latest_encrypted_events.read().unwrap().capacity()); @@ -752,11 +756,13 @@ async fn cache_latest_events( .as_original() .and_then(|profile| profile.content.displayname.as_ref()) .and_then(|display_name| { + let display_name = DisplayName::new(display_name); + changes.ambiguity_maps.get(room.room_id()).and_then( |map_for_room| { - map_for_room - .get(display_name) - .map(|user_ids| user_ids.len() > 1) + map_for_room.get(&display_name).map(|users| { + is_display_name_ambiguous(&display_name, users) + }) }, ) }); diff --git a/crates/matrix-sdk-base/src/store/ambiguity_map.rs b/crates/matrix-sdk-base/src/store/ambiguity_map.rs index f02de1cc645..45a53a794bc 100644 --- a/crates/matrix-sdk-base/src/store/ambiguity_map.rs +++ b/crates/matrix-sdk-base/src/store/ambiguity_map.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashMap}, sync::Arc, }; @@ -24,18 +24,18 @@ use ruma::{ }, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId, }; -use tracing::trace; +use tracing::{instrument, trace}; use super::{DynStateStore, Result, StateChanges}; use crate::{ - deserialized_responses::{AmbiguityChange, RawMemberEvent}, + deserialized_responses::{AmbiguityChange, DisplayName, RawMemberEvent}, store::StateStoreExt, }; /// A map of users that use a certain display name. #[derive(Debug, Clone)] struct DisplayNameUsers { - display_name: String, + display_name: DisplayName, users: BTreeSet, } @@ -70,7 +70,7 @@ impl DisplayNameUsers { /// Is the display name considered to be ambiguous. fn is_ambiguous(&self) -> bool { - self.user_count() > 1 + is_display_name_ambiguous(&self.display_name, &self.users) } } @@ -82,10 +82,19 @@ fn is_member_active(membership: &MembershipState) -> bool { #[derive(Debug)] pub(crate) struct AmbiguityCache { pub store: Arc, - pub cache: BTreeMap>>, + pub cache: BTreeMap>>, pub changes: BTreeMap>, } +#[instrument(ret)] +pub(crate) fn is_display_name_ambiguous( + display_name: &DisplayName, + users_with_display_name: &BTreeSet, +) -> bool { + trace!("Checking if a display name is ambiguous"); + display_name.is_inherently_ambiguous() || users_with_display_name.len() > 1 +} + impl AmbiguityCache { /// Create a new [`AmbiguityCache`] backed by the given state store. pub fn new(store: Arc) -> Self { @@ -224,18 +233,15 @@ impl AmbiguityCache { async fn get_users_with_display_name( &mut self, room_id: &RoomId, - display_name: &str, + display_name: &DisplayName, ) -> Result { Ok(if let Some(u) = self.cache.entry(room_id.to_owned()).or_default().get(display_name) { - DisplayNameUsers { display_name: display_name.to_owned(), users: u.clone() } + DisplayNameUsers { display_name: display_name.clone(), users: u.clone() } } else { let users_with_display_name = self.store.get_users_with_display_name(room_id, display_name).await?; - DisplayNameUsers { - display_name: display_name.to_owned(), - users: users_with_display_name, - } + DisplayNameUsers { display_name: display_name.clone(), users: users_with_display_name } }) } @@ -254,7 +260,8 @@ impl AmbiguityCache { let old_display_name = self.get_old_display_name(changes, room_id, member_event).await?; let old_map = if let Some(old_name) = old_display_name.as_deref() { - Some(self.get_users_with_display_name(room_id, old_name).await?) + let old_display_name = DisplayName::new(old_name); + Some(self.get_users_with_display_name(room_id, &old_display_name).await?) } else { None }; @@ -275,11 +282,221 @@ impl AmbiguityCache { new }; - Some(self.get_users_with_display_name(room_id, new_display_name).await?) + let new_display_name = DisplayName::new(new_display_name); + + Some(self.get_users_with_display_name(room_id, &new_display_name).await?) } else { None }; Ok((old_map, new_map)) } + + #[cfg(test)] + fn check(&self, room_id: &RoomId, display_name: &DisplayName) -> bool { + self.cache + .get(room_id) + .and_then(|display_names| { + display_names + .get(display_name) + .map(|user_ids| is_display_name_ambiguous(display_name, user_ids)) + }) + .unwrap_or_else(|| { + panic!( + "The display name {:?} should be part of the cache {:?}", + display_name, self.cache + ) + }) + } +} + +#[cfg(test)] +mod test { + use matrix_sdk_test::async_test; + use ruma::{room_id, server_name, user_id, EventId}; + use serde_json::json; + + use super::*; + use crate::store::{IntoStateStore, MemoryStore}; + + fn generate_event(user_id: &UserId, display_name: &str) -> SyncRoomMemberEvent { + let server_name = server_name!("localhost"); + serde_json::from_value(json!({ + "content": { + "displayname": display_name, + "membership": "join" + }, + "event_id": EventId::new(server_name), + "origin_server_ts": 152037280, + "sender": user_id, + "state_key": user_id, + "type": "m.room.member", + + })) + .expect("We should be able to deserialize the static member event") + } + + macro_rules! assert_ambiguity { + ( + [ $( ($user:literal, $display_name:literal) ),* ], + [ $( ($check_display_name:literal, $ambiguous:expr) ),* ] $(,)? + ) => { + assert_ambiguity!( + [ $( ($user, $display_name) ),* ], + [ $( ($check_display_name, $ambiguous) ),* ], + "The test failed the ambiguity assertions" + ) + }; + + ( + [ $( ($user:literal, $display_name:literal) ),* ], + [ $( ($check_display_name:literal, $ambiguous:expr) ),* ], + $description:literal $(,)? + ) => { + let store = MemoryStore::new(); + let mut ambiguity_cache = AmbiguityCache::new(store.into_state_store()); + + let changes = Default::default(); + let room_id = room_id!("!foo:bar"); + + macro_rules! add_display_name { + ($u:literal, $n:literal) => { + let event = generate_event(user_id!($u), $n); + + ambiguity_cache + .handle_event(&changes, room_id, &event) + .await + .expect("We should be able to handle a member event to calculate the ambiguity."); + }; + } + + macro_rules! assert_display_name_ambiguity { + ($n:literal, $a:expr) => { + let display_name = DisplayName::new($n); + + if ambiguity_cache.check(room_id, &display_name) != $a { + let foo = if $a { "be" } else { "not be" }; + panic!("{}: the display name {} should {} ambiguous", $description, $n, foo); + } + }; + } + + $( + add_display_name!($user, $display_name); + )* + + $( + assert_display_name_ambiguity!($check_display_name, $ambiguous); + )* + }; + } + + #[async_test] + async fn test_disambiguation() { + assert_ambiguity!( + [("@alice:localhost", "alice")], + [("alice", false)], + "Alice is alone in the room" + ); + + assert_ambiguity!( + [("@alice:localhost", "alice")], + [("Alice", false)], + "Alice is alone in the room and has a capitalized display name" + ); + + assert_ambiguity!( + [("@alice:localhost", "alice"), ("@bob:localhost", "alice")], + [("alice", true)], + "Alice and bob share a display name" + ); + + assert_ambiguity!( + [ + ("@alice:localhost", "alice"), + ("@bob:localhost", "alice"), + ("@carol:localhost", "carol") + ], + [("alice", true), ("carol", false)], + "Alice and Bob share a display name, while Carol is unique" + ); + + assert_ambiguity!( + [("@alice:localhost", "alice"), ("@bob:localhost", "ALICE")], + [("alice", true)], + "Alice and Bob share a display name that is differently capitalized" + ); + + assert_ambiguity!( + [("@alice:localhost", "alice"), ("@bob:localhost", "ะฐlice")], + [("alice", true)], + "Bob tries to impersonate Alice using a cyrilic ะฐ" + ); + + assert_ambiguity!( + [("@alice:localhost", "@bob:localhost"), ("@bob:localhost", "ะฐlice")], + [("@bob:localhost", true)], + "Alice tries to impersonate bob using an mxid" + ); + + assert_ambiguity!( + [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "๐’ฎ๐’ถ๐’ฝ๐’ถ๐“ˆ๐“‡๐’ถ๐’ฝ๐“๐’ถ")], + [("Sahasrahla", true)], + "Bob tries to impersonate Alice using scripture symbols" + ); + + assert_ambiguity!( + [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "๐”–๐”ž๐”ฅ๐”ž๐”ฐ๐”ฏ๐”ž๐”ฅ๐”ฉ๐”ž")], + [("Sahasrahla", true)], + "Bob tries to impersonate Alice using fraktur symbols" + ); + + assert_ambiguity!( + [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "โ“ˆโ“โ“—โ“โ“ขโ“กโ“โ“—โ“›โ“")], + [("Sahasrahla", true)], + "Bob tries to impersonate Alice using circled symbols" + ); + + assert_ambiguity!( + [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "๐Ÿ…‚๐Ÿ„ฐ๐Ÿ„ท๐Ÿ„ฐ๐Ÿ…‚๐Ÿ…๐Ÿ„ฐ๐Ÿ„ท๐Ÿ„ป๐Ÿ„ฐ")], + [("Sahasrahla", true)], + "Bob tries to impersonate Alice using squared symbols" + ); + + assert_ambiguity!( + [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "๏ผณ๏ฝ๏ฝˆ๏ฝ๏ฝ“๏ฝ’๏ฝ๏ฝˆ๏ฝŒ๏ฝ")], + [("Sahasrahla", true)], + "Bob tries to impersonate Alice using big unicode letters" + ); + + assert_ambiguity!( + [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "\u{202e}alharsahas")], + [("Sahasrahla", true)], + "Bob tries to impersonate Alice using left to right shenanigans" + ); + + assert_ambiguity!( + [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Saฬดhasrahla")], + [("Sahasrahla", true)], + "Bob tries to impersonate Alice using a diacritical mark" + ); + + assert_ambiguity!( + [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahas\u{200B}rahla")], + [("Sahasrahla", true)], + "Bob tries to impersonate Alice using a zero-width space" + ); + + assert_ambiguity!( + [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahas\u{200D}rahla")], + [("Sahasrahla", true)], + "Bob tries to impersonate Alice using a zero-width space" + ); + + assert_ambiguity!( + [("@alice:localhost", "ff"), ("@bob:localhost", "\u{FB00}")], + [("ff", true)], + "Bob tries to impersonate Alice using a ligature" + ); + } } diff --git a/crates/matrix-sdk-base/src/store/integration_tests.rs b/crates/matrix-sdk-base/src/store/integration_tests.rs index 28c4e90cf4b..3fb6ac0bc9d 100644 --- a/crates/matrix-sdk-base/src/store/integration_tests.rs +++ b/crates/matrix-sdk-base/src/store/integration_tests.rs @@ -1,6 +1,6 @@ //! Trait and macro of integration tests for StateStore implementations. -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use assert_matches::assert_matches; use assert_matches2::assert_let; @@ -34,7 +34,8 @@ use ruma::{ use serde_json::{json, value::Value as JsonValue}; use super::{ - send_queue::SentRequestKey, DependentQueuedRequestKind, DynStateStore, ServerCapabilities, + send_queue::SentRequestKey, DependentQueuedRequestKind, DisplayName, DynStateStore, + ServerCapabilities, }; use crate::{ deserialized_responses::MemberEvent, @@ -141,13 +142,15 @@ impl StateStoreIntegrationTests for DynStateStore { room.handle_state_event(&topic_event); changes.add_state_event(room_id, topic_event, topic_raw); - let mut room_ambiguity_map = BTreeMap::new(); + let mut room_ambiguity_map = HashMap::new(); let mut room_profiles = BTreeMap::new(); let member_json: &JsonValue = &test_json::MEMBER; let member_event: SyncRoomMemberEvent = serde_json::from_value(member_json.clone()).unwrap(); - let displayname = member_event.as_original().unwrap().content.displayname.clone().unwrap(); + let displayname = DisplayName::new( + member_event.as_original().unwrap().content.displayname.as_ref().unwrap(), + ); room_ambiguity_map.insert(displayname.clone(), BTreeSet::from([user_id.to_owned()])); room_profiles.insert(user_id.to_owned(), (&member_event).into()); @@ -256,6 +259,8 @@ impl StateStoreIntegrationTests for DynStateStore { async fn test_populate_store(&self) -> Result<()> { let room_id = room_id(); let user_id = user_id(); + let display_name = DisplayName::new("example"); + self.populate().await?; assert!(self.get_kv_data(StateStoreDataKey::SyncToken).await?.is_some()); @@ -290,7 +295,7 @@ impl StateStoreIntegrationTests for DynStateStore { "Expected to find 1 joined user ids" ); assert_eq!( - self.get_users_with_display_name(room_id, "example").await?.len(), + self.get_users_with_display_name(room_id, &display_name).await?.len(), 2, "Expected to find 2 display names for room" ); @@ -962,6 +967,7 @@ impl StateStoreIntegrationTests for DynStateStore { async fn test_room_removal(&self) -> Result<()> { let room_id = room_id(); let user_id = user_id(); + let display_name = DisplayName::new("example"); let stripped_room_id = stripped_room_id(); self.populate().await?; @@ -990,7 +996,7 @@ impl StateStoreIntegrationTests for DynStateStore { "still joined users found" ); assert!( - self.get_users_with_display_name(room_id, "example").await?.is_empty(), + self.get_users_with_display_name(room_id, &display_name).await?.is_empty(), "still display names found" ); assert!(self @@ -1145,15 +1151,15 @@ impl StateStoreIntegrationTests for DynStateStore { async fn test_display_names_saving(&self) { let room_id = room_id!("!test_display_names_saving:localhost"); let user_id = user_id(); - let user_display_name = "User"; + let user_display_name = DisplayName::new("User"); let second_user_id = user_id!("@second:localhost"); let third_user_id = user_id!("@third:localhost"); - let other_display_name = "Raoul"; - let unknown_display_name = "Unknown"; + let other_display_name = DisplayName::new("Raoul"); + let unknown_display_name = DisplayName::new("Unknown"); // No event in store. let mut display_names = vec![user_display_name.to_owned()]; - let users = self.get_users_with_display_name(room_id, user_display_name).await.unwrap(); + let users = self.get_users_with_display_name(room_id, &user_display_name).await.unwrap(); assert!(users.is_empty()); let names = self.get_users_with_display_names(room_id, &display_names).await.unwrap(); assert!(names.is_empty()); @@ -1167,7 +1173,7 @@ impl StateStoreIntegrationTests for DynStateStore { .insert(user_display_name.to_owned(), [user_id.to_owned()].into()); self.save_changes(&changes).await.unwrap(); - let users = self.get_users_with_display_name(room_id, user_display_name).await.unwrap(); + let users = self.get_users_with_display_name(room_id, &user_display_name).await.unwrap(); assert_eq!(users.len(), 1); let names = self.get_users_with_display_names(room_id, &display_names).await.unwrap(); assert_eq!(names.len(), 1); @@ -1182,9 +1188,9 @@ impl StateStoreIntegrationTests for DynStateStore { self.save_changes(&changes).await.unwrap(); display_names.push(other_display_name.to_owned()); - let users = self.get_users_with_display_name(room_id, user_display_name).await.unwrap(); + let users = self.get_users_with_display_name(room_id, &user_display_name).await.unwrap(); assert_eq!(users.len(), 1); - let users = self.get_users_with_display_name(room_id, other_display_name).await.unwrap(); + let users = self.get_users_with_display_name(room_id, &other_display_name).await.unwrap(); assert_eq!(users.len(), 2); let names = self.get_users_with_display_names(room_id, &display_names).await.unwrap(); assert_eq!(names.len(), 2); diff --git a/crates/matrix-sdk-base/src/store/memory_store.rs b/crates/matrix-sdk-base/src/store/memory_store.rs index 60701a3e5f6..ac2951be983 100644 --- a/crates/matrix-sdk-base/src/store/memory_store.rs +++ b/crates/matrix-sdk-base/src/store/memory_store.rs @@ -42,7 +42,8 @@ use super::{ StateChanges, StateStore, StoreError, }; use crate::{ - deserialized_responses::RawAnySyncOrStrippedState, store::QueueWedgeError, + deserialized_responses::{DisplayName, RawAnySyncOrStrippedState}, + store::QueueWedgeError, MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue, }; @@ -61,7 +62,7 @@ pub struct MemoryStore { utd_hook_manager_data: StdRwLock>, account_data: StdRwLock>>, profiles: StdRwLock>>, - display_names: StdRwLock>>>, + display_names: StdRwLock>>>, members: StdRwLock>>, room_info: StdRwLock>, room_state: StdRwLock< @@ -701,7 +702,7 @@ impl StateStore for MemoryStore { async fn get_users_with_display_name( &self, room_id: &RoomId, - display_name: &str, + display_name: &DisplayName, ) -> Result> { Ok(self .display_names @@ -715,21 +716,18 @@ impl StateStore for MemoryStore { async fn get_users_with_display_names<'a>( &self, room_id: &RoomId, - display_names: &'a [String], - ) -> Result>> { + display_names: &'a [DisplayName], + ) -> Result>> { if display_names.is_empty() { - return Ok(BTreeMap::new()); + return Ok(HashMap::new()); } let read_guard = &self.display_names.read().unwrap(); let Some(room_names) = read_guard.get(room_id) else { - return Ok(BTreeMap::new()); + return Ok(HashMap::new()); }; - Ok(display_names - .iter() - .filter_map(|n| room_names.get(n).map(|d| (n.as_str(), d.clone()))) - .collect()) + Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect()) } async fn get_account_data_event( diff --git a/crates/matrix-sdk-base/src/store/mod.rs b/crates/matrix-sdk-base/src/store/mod.rs index c33e3259b23..fd31b917734 100644 --- a/crates/matrix-sdk-base/src/store/mod.rs +++ b/crates/matrix-sdk-base/src/store/mod.rs @@ -21,7 +21,7 @@ //! store. use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashMap}, fmt, ops::Deref, result::Result as StdResult, @@ -58,6 +58,7 @@ use tokio::sync::{broadcast, Mutex, RwLock}; use tracing::warn; use crate::{ + deserialized_responses::DisplayName, event_cache::store as event_cache_store, rooms::{normal::RoomInfoNotableUpdate, RoomInfo, RoomState}, MinimalRoomMemberEvent, Room, RoomStateFilter, SessionMeta, @@ -384,7 +385,7 @@ pub struct StateChanges { /// A map from room id to a map of a display name and a set of user ids that /// share that display name in the given room. - pub ambiguity_maps: BTreeMap>>, + pub ambiguity_maps: BTreeMap>>, } impl StateChanges { diff --git a/crates/matrix-sdk-base/src/store/traits.rs b/crates/matrix-sdk-base/src/store/traits.rs index f60189dd3d3..04dad0e24bd 100644 --- a/crates/matrix-sdk-base/src/store/traits.rs +++ b/crates/matrix-sdk-base/src/store/traits.rs @@ -14,7 +14,7 @@ use std::{ borrow::Borrow, - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashMap}, fmt, sync::Arc, }; @@ -46,7 +46,9 @@ use super::{ StoreError, }; use crate::{ - deserialized_responses::{RawAnySyncOrStrippedState, RawMemberEvent, RawSyncOrStrippedState}, + deserialized_responses::{ + DisplayName, RawAnySyncOrStrippedState, RawMemberEvent, RawSyncOrStrippedState, + }, MinimalRoomMemberEvent, RoomInfo, RoomMemberships, }; @@ -206,7 +208,7 @@ pub trait StateStore: AsyncTraitDeps { async fn get_users_with_display_name( &self, room_id: &RoomId, - display_name: &str, + display_name: &DisplayName, ) -> Result, Self::Error>; /// Get all the users that use the given display names in the given room. @@ -219,8 +221,8 @@ pub trait StateStore: AsyncTraitDeps { async fn get_users_with_display_names<'a>( &self, room_id: &RoomId, - display_names: &'a [String], - ) -> Result>, Self::Error>; + display_names: &'a [DisplayName], + ) -> Result>, Self::Error>; /// Get an event out of the account data store. /// @@ -567,7 +569,7 @@ impl StateStore for EraseStateStoreError { async fn get_users_with_display_name( &self, room_id: &RoomId, - display_name: &str, + display_name: &DisplayName, ) -> Result, Self::Error> { self.0.get_users_with_display_name(room_id, display_name).await.map_err(Into::into) } @@ -575,8 +577,8 @@ impl StateStore for EraseStateStoreError { async fn get_users_with_display_names<'a>( &self, room_id: &RoomId, - display_names: &'a [String], - ) -> Result>, Self::Error> { + display_names: &'a [DisplayName], + ) -> Result>, Self::Error> { self.0.get_users_with_display_names(room_id, display_names).await.map_err(Into::into) } diff --git a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs index 372a179c9f9..6e6193e4314 100644 --- a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs +++ b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::{ - collections::{BTreeMap, BTreeSet, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, sync::Arc, }; @@ -23,7 +23,7 @@ use gloo_utils::format::JsValueSerdeExt; use growable_bloom_filter::GrowableBloom; use indexed_db_futures::prelude::*; use matrix_sdk_base::{ - deserialized_responses::RawAnySyncOrStrippedState, + deserialized_responses::{DisplayName, RawAnySyncOrStrippedState}, store::{ ChildTransactionId, ComposerDraft, DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequest, QueuedRequestKind, SentRequestKey, SerializableEventContent, @@ -665,7 +665,15 @@ impl_state_store!({ let store = tx.object_store(keys::DISPLAY_NAMES)?; for (room_id, ambiguity_maps) in &changes.ambiguity_maps { for (display_name, map) in ambiguity_maps { - let key = self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name)); + let key = self.encode_key( + keys::DISPLAY_NAMES, + ( + room_id, + display_name + .as_normalized_str() + .unwrap_or_else(|| display_name.as_raw_str()), + ), + ); store.put_key_val(&key, &self.serialize_value(&map)?)?; } @@ -1122,12 +1130,18 @@ impl_state_store!({ async fn get_users_with_display_name( &self, room_id: &RoomId, - display_name: &str, + display_name: &DisplayName, ) -> Result> { self.inner .transaction_on_one_with_mode(keys::DISPLAY_NAMES, IdbTransactionMode::Readonly)? .object_store(keys::DISPLAY_NAMES)? - .get(&self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name)))? + .get(&self.encode_key( + keys::DISPLAY_NAMES, + ( + room_id, + display_name.as_normalized_str().unwrap_or_else(|| display_name.as_raw_str()), + ), + ))? .await? .map(|f| self.deserialize_value::>(&f)) .unwrap_or_else(|| Ok(Default::default())) @@ -1136,10 +1150,12 @@ impl_state_store!({ async fn get_users_with_display_names<'a>( &self, room_id: &RoomId, - display_names: &'a [String], - ) -> Result>> { + display_names: &'a [DisplayName], + ) -> Result>> { + let mut map = HashMap::new(); + if display_names.is_empty() { - return Ok(BTreeMap::new()); + return Ok(map); } let txn = self @@ -1147,15 +1163,24 @@ impl_state_store!({ .transaction_on_one_with_mode(keys::DISPLAY_NAMES, IdbTransactionMode::Readonly)?; let store = txn.object_store(keys::DISPLAY_NAMES)?; - let mut map = BTreeMap::new(); for display_name in display_names { if let Some(user_ids) = store - .get(&self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name)))? + .get( + &self.encode_key( + keys::DISPLAY_NAMES, + ( + room_id, + display_name + .as_normalized_str() + .unwrap_or_else(|| display_name.as_raw_str()), + ), + ), + )? .await? .map(|f| self.deserialize_value::>(&f)) .transpose()? { - map.insert(display_name.as_ref(), user_ids); + map.insert(display_name, user_ids); } } diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index ab152289059..02aeb0f545a 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -1,6 +1,6 @@ use std::{ borrow::Cow, - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashMap}, fmt, iter, path::Path, sync::Arc, @@ -9,7 +9,7 @@ use std::{ use async_trait::async_trait; use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime}; use matrix_sdk_base::{ - deserialized_responses::{RawAnySyncOrStrippedState, SyncOrStrippedState}, + deserialized_responses::{DisplayName, RawAnySyncOrStrippedState, SyncOrStrippedState}, store::{ migration_helpers::RoomInfoV1, ChildTransactionId, DependentQueuedRequest, DependentQueuedRequestKind, QueueWedgeError, QueuedRequest, QueuedRequestKind, @@ -1305,13 +1305,34 @@ impl StateStore for SqliteStateStore { let room_id = this.encode_key(keys::DISPLAY_NAME, room_id); for (name, user_ids) in display_names { - let name = this.encode_key(keys::DISPLAY_NAME, name); + let encoded_name = this.encode_key( + keys::DISPLAY_NAME, + name.as_normalized_str().unwrap_or_else(|| name.as_raw_str()), + ); let data = this.serialize_json(&user_ids)?; if user_ids.is_empty() { - txn.remove_display_name(&room_id, &name)?; + txn.remove_display_name(&room_id, &encoded_name)?; + + // We can't do a migration to merge the previously distinct buckets of + // user IDs since the display names themselves are hashed before they + // are persisted in the store. So the store will always retain two + // buckets: one for raw display names and one for normalised ones. + // + // We therefore do the next best thing, which is a sort of a soft + // migration: we fetch both the raw and normalised buckets, then merge + // the user IDs contained in them into a separate, temporary merged + // bucket. The SDK then operates on the merged buckets exclusively. See + // the comment in `get_users_with_display_names` for details. + // + // If the merged bucket is empty, that must mean that both the raw and + // normalised buckets were also empty, so we can remove both from the + // store. + let raw_name = this.encode_key(keys::DISPLAY_NAME, name.as_raw_str()); + txn.remove_display_name(&room_id, &raw_name)?; } else { - txn.set_display_name(&room_id, &name, &data)?; + // We only create new buckets with the normalized display name. + txn.set_display_name(&room_id, &encoded_name, &data)?; } } } @@ -1500,10 +1521,13 @@ impl StateStore for SqliteStateStore { async fn get_users_with_display_name( &self, room_id: &RoomId, - display_name: &str, + display_name: &DisplayName, ) -> Result> { let room_id = self.encode_key(keys::DISPLAY_NAME, room_id); - let names = vec![self.encode_key(keys::DISPLAY_NAME, display_name)]; + let names = vec![self.encode_key( + keys::DISPLAY_NAME, + display_name.as_normalized_str().unwrap_or_else(|| display_name.as_raw_str()), + )]; Ok(self .acquire() @@ -1520,33 +1544,49 @@ impl StateStore for SqliteStateStore { async fn get_users_with_display_names<'a>( &self, room_id: &RoomId, - display_names: &'a [String], - ) -> Result>> { + display_names: &'a [DisplayName], + ) -> Result>> { + let mut result = HashMap::new(); + if display_names.is_empty() { - return Ok(BTreeMap::new()); + return Ok(result); } let room_id = self.encode_key(keys::DISPLAY_NAME, room_id); let mut names_map = display_names .iter() - .map(|n| (self.encode_key(keys::DISPLAY_NAME, n), n.as_ref())) + .flat_map(|display_name| { + // We encode the display name as the `raw_str()` and the normalized string. + // + // This is for compatibility reasons since: + // 1. Previously "Alice" and "alice" were considered to be distinct display + // names, while we now consider them to be the same so we need to merge the + // previously distinct buckets of user IDs. + // 2. We can't do a migration to merge the previously distinct buckets of user + // IDs since the display names itself are hashed before they are persisted + // in the store. + let raw = + (self.encode_key(keys::DISPLAY_NAME, display_name.as_raw_str()), display_name); + let normalized = display_name.as_normalized_str().map(|normalized| { + (self.encode_key(keys::DISPLAY_NAME, normalized), display_name) + }); + + iter::once(raw).chain(normalized.into_iter()) + }) .collect::>(); let names = names_map.keys().cloned().collect(); - self.acquire() - .await? - .get_display_names(room_id, names) - .await? - .into_iter() - .map(|(name, data)| { - Ok(( - names_map - .remove(name.as_slice()) - .expect("returned display names were requested"), - self.deserialize_json(&data)?, - )) - }) - .collect::>>() + for (name, data) in + self.acquire().await?.get_display_names(room_id, names).await?.into_iter() + { + let display_name = + names_map.remove(name.as_slice()).expect("returned display names were requested"); + let user_ids: BTreeSet<_> = self.deserialize_json(&data)?; + + result.entry(display_name).or_insert_with(BTreeSet::new).extend(user_ids); + } + + Ok(result) } async fn get_account_data_event(