diff --git a/libsignal-service/src/cipher.rs b/libsignal-service/src/cipher.rs index 063b4207b..ea79eee73 100644 --- a/libsignal-service/src/cipher.rs +++ b/libsignal-service/src/cipher.rs @@ -55,17 +55,13 @@ fn debug_envelope(envelope: &Envelope) -> String { } else { format!( "Envelope {{ \ - source_address: {}, \ + source_address: {:?}, \ source_device: {:?}, \ server_guid: {:?}, \ timestamp: {:?}, \ content: {} bytes, \ }}", - if envelope.source_service_id.is_some() { - format!("{:?}", envelope.source_address()) - } else { - "unknown".to_string() - }, + envelope.source_service_id, envelope.source_device(), envelope.server_guid(), envelope.timestamp(), @@ -278,13 +274,13 @@ where ) .await?; - let sender = ServiceAddress { - uuid: Uuid::parse_str(&sender_uuid).map_err(|_| { + let sender = ServiceAddress::try_from(sender_uuid.as_str()) + .map_err(|e| { + tracing::error!("{:?}", e); SignalProtocolError::InvalidSealedSenderMessage( "invalid sender UUID".to_string(), ) - })?, - }; + })?; let needs_receipt = if envelope.source_service_id.is_some() { tracing::warn!(?envelope, "Received an unidentified delivery over an identified channel. Marking needs_receipt=false"); diff --git a/libsignal-service/src/envelope.rs b/libsignal-service/src/envelope.rs index a68aa6edd..75f0d6d89 100644 --- a/libsignal-service/src/envelope.rs +++ b/libsignal-service/src/envelope.rs @@ -3,7 +3,6 @@ use std::convert::{TryFrom, TryInto}; use aes::cipher::block_padding::Pkcs7; use aes::cipher::{BlockDecryptMut, KeyIvInit}; use prost::Message; -use uuid::Uuid; use crate::{ configuration::SignalingKey, push_service::ServiceError, @@ -135,15 +134,11 @@ impl Envelope { } pub fn source_address(&self) -> ServiceAddress { - let uuid = self - .source_service_id - .as_deref() - .map(Uuid::parse_str) - .transpose() - .expect("valid uuid checked in constructor") - .expect("source_service_id is set"); - - ServiceAddress { uuid } + match self.source_service_id.as_deref() { + Some(service_id) => ServiceAddress::try_from(service_id) + .expect("invalid ProtocolAddress UUID or prefix"), + None => panic!("source_service_id is set"), + } } } diff --git a/libsignal-service/src/profile_service.rs b/libsignal-service/src/profile_service.rs index ae17e0fc7..e23a4b51b 100644 --- a/libsignal-service/src/profile_service.rs +++ b/libsignal-service/src/profile_service.rs @@ -21,9 +21,10 @@ impl ProfileService { ) -> Result { let endpoint = match profile_key { Some(key) => { - let version = bincode::serialize( - &key.get_profile_key_version(address.aci()), - )?; + let version = + bincode::serialize(&key.get_profile_key_version( + address.aci().expect("profile by ACI ProtocolAddress"), + ))?; let version = std::str::from_utf8(&version) .expect("hex encoded profile key version"); format!("/v1/profile/{}/{}", address.uuid, version) diff --git a/libsignal-service/src/push_service.rs b/libsignal-service/src/push_service.rs index 845ef9b9d..073f40459 100644 --- a/libsignal-service/src/push_service.rs +++ b/libsignal-service/src/push_service.rs @@ -75,7 +75,7 @@ pub const STICKER_PATH: &str = "stickers/%s/full/%d"; pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55); pub const DEFAULT_DEVICE_ID: u32 = 1; -#[derive(Debug, Clone, Copy, Eq, PartialEq)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] pub enum ServiceIdType { /// Account Identity (ACI) /// @@ -811,7 +811,7 @@ pub trait PushService: MaybeSend { &mut self, messages: OutgoingPushMessages, ) -> Result { - let path = format!("/v1/messages/{}", messages.recipient.uuid); + let path = format!("/v1/messages/{}", messages.destination); self.put_json( Endpoint::Service, &path, @@ -902,9 +902,9 @@ pub trait PushService: MaybeSend { profile_key: Option, ) -> Result { let endpoint = if let Some(key) = profile_key { - let version = bincode::serialize( - &key.get_profile_key_version(address.aci()), - )?; + let version = bincode::serialize(&key.get_profile_key_version( + address.aci().expect("profile by ACI ProtocolAddress"), + ))?; let version = std::str::from_utf8(&version) .expect("hex encoded profile key version"); format!("/v1/profile/{}/{}", address.uuid, version) diff --git a/libsignal-service/src/sender.rs b/libsignal-service/src/sender.rs index d838f8e33..cb0627ca7 100644 --- a/libsignal-service/src/sender.rs +++ b/libsignal-service/src/sender.rs @@ -8,7 +8,6 @@ use libsignal_protocol::{ use rand::{CryptoRng, Rng}; use tracing::{info, trace}; use tracing_futures::Instrument; -use uuid::Uuid; use crate::{ cipher::{get_preferred_protocol_address, ServiceCipher}, @@ -38,7 +37,7 @@ pub struct OutgoingPushMessage { #[derive(serde::Serialize, Debug)] pub struct OutgoingPushMessages { - pub recipient: ServiceAddress, + pub destination: uuid::Uuid, pub timestamp: u64, pub messages: Vec, pub online: bool, @@ -120,8 +119,8 @@ pub enum MessageSenderError { #[error("Proof of type {options:?} required using token {token}")] ProofRequired { token: String, options: Vec }, - #[error("Recipient not found: {uuid}")] - NotFound { uuid: Uuid }, + #[error("Recipient not found: {addr:?}")] + NotFound { addr: ServiceAddress }, } impl MessageSender @@ -500,7 +499,7 @@ where .await?; let messages = OutgoingPushMessages { - recipient, + destination: recipient.uuid, timestamp, messages, online, @@ -601,7 +600,7 @@ where Err(ServiceError::NotFoundError) => { tracing::debug!("Not found when sending a message"); return Err(MessageSenderError::NotFound { - uuid: recipient.uuid, + addr: recipient, }); }, Err(e) => { @@ -722,9 +721,22 @@ where devices.insert(DEFAULT_DEVICE_ID.into()); // never try to send messages to the sender device - if recipient.aci() == self.local_aci.aci() { - devices.remove(&self.device_id); - } + match recipient.identity { + ServiceIdType::AccountIdentity => { + if recipient.aci().is_some() + && recipient.aci() == self.local_aci.aci() + { + devices.remove(&self.device_id); + } + }, + ServiceIdType::PhoneNumberIdentity => { + if recipient.pni().is_some() + && recipient.pni() == self.local_aci.pni() + { + devices.remove(&self.device_id); + } + }, + }; for device_id in devices { trace!("sending message to device {}", device_id); @@ -836,7 +848,7 @@ where }, Err(ServiceError::NotFoundError) => { return Err(MessageSenderError::NotFound { - uuid: recipient.uuid, + addr: *recipient, }); }, Err(e) => Err(e)?, diff --git a/libsignal-service/src/service_address.rs b/libsignal-service/src/service_address.rs index 423ff3737..912c06503 100644 --- a/libsignal-service/src/service_address.rs +++ b/libsignal-service/src/service_address.rs @@ -1,9 +1,10 @@ use std::convert::TryFrom; use libsignal_protocol::{DeviceId, ProtocolAddress}; -use serde::{Deserialize, Serialize}; use uuid::Uuid; +pub use crate::push_service::ServiceIdType; + #[derive(thiserror::Error, Debug, Clone)] pub enum ParseServiceAddressError { #[error("Supplied UUID could not be parsed")] @@ -13,9 +14,10 @@ pub enum ParseServiceAddressError { NoUuid, } -#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] pub struct ServiceAddress { pub uuid: Uuid, + pub identity: ServiceIdType, } impl ServiceAddress { @@ -23,54 +25,106 @@ impl ServiceAddress { &self, device_id: impl Into, ) -> ProtocolAddress { - ProtocolAddress::new(self.uuid.to_string(), device_id.into()) + match self.identity { + ServiceIdType::AccountIdentity => { + ProtocolAddress::new(self.uuid.to_string(), device_id.into()) + }, + ServiceIdType::PhoneNumberIdentity => ProtocolAddress::new( + format!("PNI:{}", self.uuid), + device_id.into(), + ), + } } - pub fn aci(&self) -> libsignal_protocol::Aci { - libsignal_protocol::Aci::from_uuid_bytes(self.uuid.into_bytes()) + pub fn new_aci(uuid: Uuid) -> Self { + Self { + uuid, + identity: ServiceIdType::AccountIdentity, + } } - pub fn pni(&self) -> libsignal_protocol::Pni { - libsignal_protocol::Pni::from_uuid_bytes(self.uuid.into_bytes()) + pub fn new_pni(uuid: Uuid) -> Self { + Self { + uuid, + identity: ServiceIdType::PhoneNumberIdentity, + } } -} -impl From for ServiceAddress { - fn from(uuid: Uuid) -> Self { - Self { uuid } + pub fn aci(&self) -> Option { + use libsignal_protocol::Aci; + match self.identity { + ServiceIdType::AccountIdentity => { + Some(Aci::from_uuid_bytes(self.uuid.into_bytes())) + }, + ServiceIdType::PhoneNumberIdentity => None, + } + } + + pub fn pni(&self) -> Option { + use libsignal_protocol::Pni; + match self.identity { + ServiceIdType::AccountIdentity => None, + ServiceIdType::PhoneNumberIdentity => { + Some(Pni::from_uuid_bytes(self.uuid.into_bytes())) + }, + } + } + + pub fn to_service_id(&self) -> String { + match self.identity { + ServiceIdType::AccountIdentity => self.uuid.to_string(), + ServiceIdType::PhoneNumberIdentity => { + format!("PNI:{}", self.uuid) + }, + } } } -impl TryFrom<&str> for ServiceAddress { +impl TryFrom<&ProtocolAddress> for ServiceAddress { type Error = ParseServiceAddressError; - fn try_from(value: &str) -> Result { - Ok(ServiceAddress { - uuid: Uuid::parse_str(value)?, + fn try_from(addr: &ProtocolAddress) -> Result { + let value = addr.name(); + if let Some(pni) = value.strip_prefix("PNI:") { + Ok(ServiceAddress::new_pni(Uuid::parse_str(pni)?)) + } else { + Ok(ServiceAddress::new_aci(Uuid::parse_str(value)?)) + } + .map_err(|e| { + tracing::error!("Parsing ServiceAddress from {:?}", addr); + ParseServiceAddressError::InvalidUuid(e) }) } } -impl TryFrom> for ServiceAddress { +impl TryFrom<&str> for ServiceAddress { type Error = ParseServiceAddressError; - fn try_from(value: Option<&str>) -> Result { - match value.map(Uuid::parse_str) { - Some(Ok(uuid)) => Ok(ServiceAddress { uuid }), - Some(Err(e)) => Err(ParseServiceAddressError::InvalidUuid(e)), - None => Err(ParseServiceAddressError::NoUuid), + fn try_from(value: &str) -> Result { + if let Some(pni) = value.strip_prefix("PNI:") { + Ok(ServiceAddress::new_pni(Uuid::parse_str(pni)?)) + } else { + Ok(ServiceAddress::new_aci(Uuid::parse_str(value)?)) } + .map_err(|e| { + tracing::error!("Parsing ServiceAddress from '{}'", value); + ParseServiceAddressError::InvalidUuid(e) + }) } } -impl TryFrom> for ServiceAddress { +impl TryFrom<&[u8]> for ServiceAddress { type Error = ParseServiceAddressError; - fn try_from(value: Option<&[u8]>) -> Result { - match value.map(Uuid::from_slice) { - Some(Ok(uuid)) => Ok(ServiceAddress { uuid }), - Some(Err(e)) => Err(ParseServiceAddressError::InvalidUuid(e)), - None => Err(ParseServiceAddressError::NoUuid), + fn try_from(value: &[u8]) -> Result { + if let Some(pni) = value.strip_prefix(b"PNI:") { + Ok(ServiceAddress::new_pni(Uuid::from_slice(pni)?)) + } else { + Ok(ServiceAddress::new_aci(Uuid::from_slice(value)?)) } + .map_err(|e| { + tracing::error!("Parsing ServiceAddress from {:?}", value); + ParseServiceAddressError::InvalidUuid(e) + }) } } diff --git a/libsignal-service/src/websocket/sender.rs b/libsignal-service/src/websocket/sender.rs index 166d9c800..e9ea90ef6 100644 --- a/libsignal-service/src/websocket/sender.rs +++ b/libsignal-service/src/websocket/sender.rs @@ -12,7 +12,7 @@ impl SignalWebSocket { &mut self, messages: OutgoingPushMessages, ) -> Result { - let path = format!("/v1/messages/{}", messages.recipient.uuid); + let path = format!("/v1/messages/{}", messages.destination); self.put_json(&path, messages).await } @@ -21,7 +21,7 @@ impl SignalWebSocket { messages: OutgoingPushMessages, access: &UnidentifiedAccess, ) -> Result { - let path = format!("/v1/messages/{}", messages.recipient.uuid); + let path = format!("/v1/messages/{}", messages.destination); let header = format!( "Unidentified-Access-Key:{}", BASE64_RELAXED.encode(&access.key)