diff --git a/src/claims.rs b/src/claims.rs index 3fa1119..00591af 100644 --- a/src/claims.rs +++ b/src/claims.rs @@ -1,4 +1,4 @@ -use crate::helpers::{timestamp_to_utc, utc_to_seconds, Boolean, FlattenFilter, Timestamp}; +use crate::helpers::{Boolean, DeserializeMapField, FlattenFilter, Timestamp}; use crate::types::localized::split_language_tag_key; use crate::{ AddressCountry, AddressLocality, AddressPostalCode, AddressRegion, EndUserBirthday, @@ -214,31 +214,128 @@ where where V: MapAccess<'de>, { - deserialize_fields! { - map { - [sub] - [LanguageTag(name)] - [LanguageTag(given_name)] - [LanguageTag(family_name)] - [LanguageTag(middle_name)] - [LanguageTag(nickname)] - [Option(preferred_username)] - [LanguageTag(profile)] - [LanguageTag(picture)] - [LanguageTag(website)] - [Option(email)] - [Option(Boolean(email_verified))] - [Option(gender)] - [Option(birthday)] - [Option(birthdate)] - [Option(zoneinfo)] - [Option(locale)] - [Option(phone_number)] - [Option(Boolean(phone_number_verified))] - [Option(address)] - [Option(DateTime(Seconds(updated_at)))] - } + // NB: The non-localized fields are actually Option> here so that we can + // distinguish between omitted fields and fields explicitly set to `null`. The + // latter is necessary so that we can detect duplicate fields (e.g., if a key is + // present both with a null value and a non-null value, that's an error). + let mut sub = None; + let mut name = None; + let mut given_name = None; + let mut family_name = None; + let mut middle_name = None; + let mut nickname = None; + let mut preferred_username = None; + let mut profile = None; + let mut picture = None; + let mut website = None; + let mut email = None; + let mut email_verified = None; + let mut gender = None; + let mut birthday = None; + let mut birthdate = None; + let mut zoneinfo = None; + let mut locale = None; + let mut phone_number = None; + let mut phone_number_verified = None; + let mut address = None; + let mut updated_at = None; + + macro_rules! field_case { + ($field:ident, $typ:ty, $language_tag:ident) => {{ + $field = Some(<$typ>::deserialize_map_field( + &mut map, + stringify!($field), + $language_tag, + $field, + )?); + }}; + } + + while let Some(key) = map.next_key::()? { + let (field_name, language_tag) = split_language_tag_key(&key); + match field_name { + "sub" => field_case!(sub, SubjectIdentifier, language_tag), + "name" => field_case!(name, LocalizedClaim>, language_tag), + "given_name" => { + field_case!(given_name, LocalizedClaim>, language_tag) + } + "family_name" => { + field_case!(family_name, LocalizedClaim>, language_tag) + } + "middle_name" => { + field_case!(middle_name, LocalizedClaim>, language_tag) + } + "nickname" => { + field_case!(nickname, LocalizedClaim>, language_tag) + } + "preferred_username" => { + field_case!(preferred_username, Option<_>, language_tag) + } + "profile" => { + field_case!(profile, LocalizedClaim>, language_tag) + } + "picture" => { + field_case!(picture, LocalizedClaim>, language_tag) + } + "website" => { + field_case!(website, LocalizedClaim>, language_tag) + } + "email" => field_case!(email, Option<_>, language_tag), + "email_verified" => { + field_case!(email_verified, Option, language_tag) + } + "gender" => field_case!(gender, Option<_>, language_tag), + "birthday" => field_case!(birthday, Option<_>, language_tag), + "birthdate" => field_case!(birthdate, Option<_>, language_tag), + "zoneinfo" => field_case!(zoneinfo, Option<_>, language_tag), + "locale" => field_case!(locale, Option<_>, language_tag), + "phone_number" => field_case!(phone_number, Option<_>, language_tag), + "phone_number_verified" => { + field_case!(phone_number_verified, Option, language_tag) + } + "address" => field_case!(address, Option<_>, language_tag), + "updated_at" => field_case!(updated_at, Option, language_tag), + // Ignore unknown fields. + _ => { + map.next_value::()?; + continue; + } + }; } + + Ok(StandardClaims { + sub: sub.ok_or_else(|| serde::de::Error::missing_field("sub"))?, + name: name.and_then(LocalizedClaim::flatten_or_none), + given_name: given_name.and_then(LocalizedClaim::flatten_or_none), + family_name: family_name.and_then(LocalizedClaim::flatten_or_none), + middle_name: middle_name.and_then(LocalizedClaim::flatten_or_none), + nickname: nickname.and_then(LocalizedClaim::flatten_or_none), + preferred_username: preferred_username.flatten(), + profile: profile.and_then(LocalizedClaim::flatten_or_none), + picture: picture.and_then(LocalizedClaim::flatten_or_none), + website: website.and_then(LocalizedClaim::flatten_or_none), + email: email.flatten(), + email_verified: email_verified.flatten().map(Boolean::into_inner), + gender: gender.flatten(), + birthday: birthday.flatten(), + birthdate: birthdate.flatten(), + zoneinfo: zoneinfo.flatten(), + locale: locale.flatten(), + phone_number: phone_number.flatten(), + phone_number_verified: phone_number_verified.flatten().map(Boolean::into_inner), + address: address.flatten(), + updated_at: updated_at + .flatten() + .map(|sec| { + sec.to_utc().map_err(|_| { + serde::de::Error::custom(format!( + "failed to parse `{sec}` as UTC datetime (in seconds) for key \ + `updated_at`" + )) + }) + }) + .transpose()?, + }) } } deserializer.deserialize_map(ClaimsVisitor(PhantomData)) @@ -280,3 +377,106 @@ where } } } + +#[cfg(test)] +mod tests { + use crate::core::CoreGenderClaim; + use crate::StandardClaims; + + // The spec states (https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse): + // "If a Claim is not returned, that Claim Name SHOULD be omitted from the JSON object + // representing the Claims; it SHOULD NOT be present with a null or empty string value." + // However, we still aim to support identity providers that disregard this suggestion. + #[test] + fn test_null_optional_claims() { + let claims = serde_json::from_str::>( + r#"{ + "sub": "24400320", + "name": null, + "given_name": null, + "family_name": null, + "middle_name": null, + "nickname": null, + "preferred_username": null, + "profile": null, + "picture": null, + "website": null, + "email": null, + "email_verified": null, + "gender": null, + "birthday": null, + "birthdate": null, + "zoneinfo": null, + "locale": null, + "phone_number": null, + "phone_number_verified": null, + "address": null, + "updated_at": null + }"#, + ) + .expect("should deserialize successfully"); + + assert_eq!(claims.subject().as_str(), "24400320"); + assert_eq!(claims.name(), None); + } + + fn expect_err_prefix( + result: Result, serde_json::Error>, + expected_prefix: &str, + ) { + let err_str = result.expect_err("deserialization should fail").to_string(); + assert!( + err_str.starts_with(expected_prefix), + "error message should begin with `{}`: {}", + expected_prefix, + err_str, + ) + } + + #[test] + fn test_duplicate_claims() { + expect_err_prefix( + serde_json::from_str( + r#"{ + "sub": "24400320", + "sub": "24400321" + }"#, + ), + "duplicate field `sub` at line", + ); + + expect_err_prefix( + serde_json::from_str( + r#"{ + "name": null, + "sub": "24400320", + "name": "foo", + }"#, + ), + "duplicate field `name` at line", + ); + + expect_err_prefix( + serde_json::from_str( + r#"{ + "name#en": null, + "sub": "24400320", + "name#en": "foo", + }"#, + ), + "duplicate field `name#en` at line", + ); + } + + #[test] + fn test_err_field_name() { + expect_err_prefix( + serde_json::from_str( + r#"{ + "sub": 24400320 + }"#, + ), + "sub: invalid type: integer `24400320`, expected a string at line", + ); + } +} diff --git a/src/helpers.rs b/src/helpers.rs index 76f5457..90c3090 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -1,9 +1,13 @@ +use crate::types::localized::join_language_tag_key; +use crate::{LanguageTag, LocalizedClaim}; + use chrono::{DateTime, TimeZone, Utc}; use serde::de::value::MapDeserializer; -use serde::de::{DeserializeOwned, Deserializer, MapAccess, Visitor}; +use serde::de::{DeserializeOwned, Deserializer, Error, MapAccess, Visitor}; use serde::{Deserialize, Serialize, Serializer}; use serde_json::from_value; use serde_value::ValueDeserializer; +use serde_with::{DeserializeAs, SerializeAs}; use std::cmp::PartialEq; use std::fmt::{Debug, Display, Formatter, Result as FormatterResult}; @@ -14,8 +18,6 @@ where T: DeserializeOwned, D: Deserializer<'de>, { - use serde::de::Error; - let value: serde_json::Value = Deserialize::deserialize(deserializer)?; match from_value::>(value.clone()) { Ok(val) => Ok(val), @@ -33,8 +35,6 @@ where T: DeserializeOwned, D: Deserializer<'de>, { - use serde::de::Error; - let value: serde_json::Value = Deserialize::deserialize(deserializer)?; match from_value::>>(value.clone()) { Ok(val) => Ok(val), @@ -62,6 +62,79 @@ where } } +pub trait DeserializeMapField: Sized { + fn deserialize_map_field<'de, V>( + map: &mut V, + field_name: &'static str, + language_tag: Option, + field_value: Option, + ) -> Result + where + V: MapAccess<'de>; +} + +impl DeserializeMapField for T +where + T: DeserializeOwned, +{ + fn deserialize_map_field<'de, V>( + map: &mut V, + field_name: &'static str, + language_tag: Option, + field_value: Option, + ) -> Result + where + V: MapAccess<'de>, + { + if field_value.is_some() { + return Err(serde::de::Error::duplicate_field(field_name)); + } else if let Some(language_tag) = language_tag { + return Err(serde::de::Error::custom(format!( + "unexpected language tag `{language_tag}` for key `{field_name}`" + ))); + } + map.next_value().map_err(|err| { + V::Error::custom(format!( + "{}: {err}", + join_language_tag_key(field_name, language_tag.as_ref()) + )) + }) + } +} + +impl DeserializeMapField for LocalizedClaim +where + T: DeserializeOwned, +{ + fn deserialize_map_field<'de, V>( + map: &mut V, + field_name: &'static str, + language_tag: Option, + field_value: Option, + ) -> Result + where + V: MapAccess<'de>, + { + let mut localized_claim = field_value.unwrap_or_default(); + if localized_claim.contains_key(language_tag.as_ref()) { + return Err(serde::de::Error::custom(format!( + "duplicate field `{}`", + join_language_tag_key(field_name, language_tag.as_ref()) + ))); + } + + let localized_value = map.next_value().map_err(|err| { + V::Error::custom(format!( + "{}: {err}", + join_language_tag_key(field_name, language_tag.as_ref()) + )) + })?; + localized_claim.insert(language_tag, localized_value); + + Ok(localized_claim) + } +} + // Some providers return boolean values as strings. Provide support for // parsing using stdlib. #[cfg(feature = "accept-string-booleans")] @@ -320,7 +393,11 @@ pub(crate) struct Boolean( )] pub bool, ); - +impl Boolean { + pub(crate) fn into_inner(self) -> bool { + self.0 + } +} impl Display for Boolean { fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { Display::fmt(&self.0, f) @@ -335,105 +412,72 @@ pub(crate) enum Timestamp { #[cfg(feature = "accept-rfc3339-timestamps")] Rfc3339(String), } +impl Timestamp { + // The spec is ambiguous about whether seconds should be expressed as integers, or + // whether floating-point values are allowed. For compatibility with a wide range of + // clients, we round down to the nearest second. + pub(crate) fn from_utc(utc: &DateTime) -> Self { + Timestamp::Seconds(utc.timestamp().into()) + } -impl Display for Timestamp { - fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { + pub(crate) fn to_utc(&self) -> Result, ()> { match self { - Timestamp::Seconds(seconds) => Display::fmt(seconds, f), + Timestamp::Seconds(seconds) => { + let (secs, nsecs) = if seconds.is_i64() { + (seconds.as_i64().ok_or(())?, 0u32) + } else { + let secs_f64 = seconds.as_f64().ok_or(())?; + let secs = secs_f64.floor(); + ( + secs as i64, + ((secs_f64 - secs) * 1_000_000_000.).floor() as u32, + ) + }; + Utc.timestamp_opt(secs, nsecs).single().ok_or(()) + } #[cfg(feature = "accept-rfc3339-timestamps")] - Timestamp::Rfc3339(iso) => Display::fmt(iso, f), + Timestamp::Rfc3339(iso) => { + let datetime = DateTime::parse_from_rfc3339(iso).map_err(|_| ())?; + Ok(datetime.into()) + } } } } -pub(crate) fn timestamp_to_utc(timestamp: &Timestamp) -> Result, ()> { - match timestamp { - Timestamp::Seconds(seconds) => { - let (secs, nsecs) = if seconds.is_i64() { - (seconds.as_i64().ok_or(())?, 0u32) - } else { - let secs_f64 = seconds.as_f64().ok_or(())?; - let secs = secs_f64.floor(); - ( - secs as i64, - ((secs_f64 - secs) * 1_000_000_000.).floor() as u32, - ) - }; - Utc.timestamp_opt(secs, nsecs).single().ok_or(()) - } - #[cfg(feature = "accept-rfc3339-timestamps")] - Timestamp::Rfc3339(iso) => { - let datetime = DateTime::parse_from_rfc3339(iso).map_err(|_| ())?; - Ok(datetime.into()) +impl Display for Timestamp { + fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { + match self { + Timestamp::Seconds(seconds) => Display::fmt(seconds, f), + #[cfg(feature = "accept-rfc3339-timestamps")] + Timestamp::Rfc3339(iso) => Display::fmt(iso, f), } } } -pub mod serde_utc_seconds { - use crate::helpers::{timestamp_to_utc, utc_to_seconds, Timestamp}; - - use chrono::{DateTime, Utc}; - use serde::{Deserialize, Deserializer, Serialize, Serializer}; - - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> +impl<'de> DeserializeAs<'de, DateTime> for Timestamp { + fn deserialize_as(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { let seconds: Timestamp = Deserialize::deserialize(deserializer)?; - timestamp_to_utc(&seconds).map_err(|_| { + seconds.to_utc().map_err(|_| { serde::de::Error::custom(format!( "failed to parse `{}` as UTC datetime (in seconds)", seconds )) }) } - - pub fn serialize(v: &DateTime, serializer: S) -> Result - where - S: Serializer, - { - utc_to_seconds(v).serialize(serializer) - } } -pub mod serde_utc_seconds_opt { - use crate::helpers::{timestamp_to_utc, utc_to_seconds, Timestamp}; - - use chrono::{DateTime, Utc}; - use serde::{Deserialize, Deserializer, Serialize, Serializer}; - - pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> - where - D: Deserializer<'de>, - { - let seconds: Option = Deserialize::deserialize(deserializer)?; - seconds - .map(|sec| { - timestamp_to_utc(&sec).map_err(|_| { - serde::de::Error::custom(format!( - "failed to parse `{}` as UTC datetime (in seconds)", - sec - )) - }) - }) - .transpose() - } - - pub fn serialize(v: &Option>, serializer: S) -> Result +impl SerializeAs> for Timestamp { + fn serialize_as(source: &DateTime, serializer: S) -> Result where S: Serializer, { - v.map(|sec| utc_to_seconds(&sec)).serialize(serializer) + Timestamp::from_utc(source).serialize(serializer) } } -// The spec is ambiguous about whether seconds should be expressed as integers, or -// whether floating-point values are allowed. For compatibility with a wide range of -// clients, we round down to the nearest second. -pub(crate) fn utc_to_seconds(utc: &DateTime) -> Timestamp { - Timestamp::Seconds(utc.timestamp().into()) -} - new_type![ #[derive(Deserialize, Hash, Serialize)] pub(crate) Base64UrlEncodedBytes( diff --git a/src/id_token/mod.rs b/src/id_token/mod.rs index 706eeaf..a900191 100644 --- a/src/id_token/mod.rs +++ b/src/id_token/mod.rs @@ -1,6 +1,4 @@ -use crate::helpers::{ - deserialize_string_or_vec, serde_utc_seconds, serde_utc_seconds_opt, FilteredFlatten, -}; +use crate::helpers::{deserialize_string_or_vec, FilteredFlatten, Timestamp}; use crate::jwt::JsonWebTokenAccess; use crate::jwt::{JsonWebTokenError, JsonWebTokenJsonPayloadSerde}; use crate::types::jwk::JwsSigningAlgorithm; @@ -19,7 +17,7 @@ use crate::{ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use serde_json::Value; -use serde_with::skip_serializing_none; +use serde_with::{serde_as, skip_serializing_none}; use std::fmt::Debug; use std::str::FromStr; @@ -198,6 +196,7 @@ where any(test, feature = "timing-resistant-secret-traits"), derive(PartialEq) )] +#[serde_as] #[skip_serializing_none] #[derive(Clone, Debug, Deserialize, Serialize)] pub struct IdTokenClaims @@ -216,11 +215,13 @@ where deserialize_with = "deserialize_string_or_vec" )] audiences: Vec, - #[serde(rename = "exp", with = "serde_utc_seconds")] + #[serde_as(as = "Timestamp")] + #[serde(rename = "exp")] expiration: DateTime, - #[serde(rename = "iat", with = "serde_utc_seconds")] + #[serde_as(as = "Timestamp")] + #[serde(rename = "iat")] issue_time: DateTime, - #[serde(default, with = "serde_utc_seconds_opt")] + #[serde_as(as = "Option")] auth_time: Option>, nonce: Option, #[serde(rename = "acr")] diff --git a/src/id_token/tests.rs b/src/id_token/tests.rs index 304ab8c..d0161fb 100644 --- a/src/id_token/tests.rs +++ b/src/id_token/tests.rs @@ -15,6 +15,7 @@ use crate::{ use chrono::{TimeZone, Utc}; use oauth2::TokenResponse; +use pretty_assertions::assert_eq; use serde::{Deserialize, Serialize}; use url::Url; diff --git a/src/jwt/mod.rs b/src/jwt/mod.rs index 14a5122..1e58007 100644 --- a/src/jwt/mod.rs +++ b/src/jwt/mod.rs @@ -439,12 +439,7 @@ where JS: JwsSigningAlgorithm, P: Debug + DeserializeOwned + Serialize, S: JsonWebTokenPayloadSerde

, - >( - PhantomData, - PhantomData, - PhantomData

, - PhantomData, - ); + >(PhantomData<(JE, JS, P, S)>); impl<'de, JE, JS, P, S> Visitor<'de> for JsonWebTokenVisitor where JE: JweContentEncryptionAlgorithm, @@ -514,12 +509,7 @@ where }) } } - deserializer.deserialize_str(JsonWebTokenVisitor( - PhantomData, - PhantomData, - PhantomData, - PhantomData, - )) + deserializer.deserialize_str(JsonWebTokenVisitor(PhantomData)) } } impl Serialize for JsonWebToken diff --git a/src/macros.rs b/src/macros.rs index 0f4793a..bb3caf9 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -427,217 +427,6 @@ macro_rules! new_url_type { }; } -macro_rules! deserialize_fields { - (@field_str Option(Seconds($field:ident))) => { stringify![$field] }; - (@field_str Option(DateTime(Seconds($field:ident)))) => { stringify![$field] }; - (@field_str Option(Boolean($field:ident))) => { stringify![$field] }; - (@field_str Option($field:ident)) => { stringify![$field] }; - (@field_str LanguageTag($field:ident)) => { stringify![$field] }; - (@field_str $field:ident) => { stringify![$field] }; - (@let_none Option(Seconds($field:ident))) => { let mut $field = None; }; - (@let_none Option(DateTime(Seconds($field:ident)))) => { let mut $field = None; }; - (@let_none Option(Boolean($field:ident))) => { let mut $field = None; }; - (@let_none Option($field:ident)) => { let mut $field = None; }; - (@let_none LanguageTag($field:ident)) => { let mut $field = None; }; - (@let_none $field:ident) => { let mut $field = None; }; - (@case $map:ident $key:ident $language_tag_opt:ident Option(Seconds($field:ident))) => { - if $field.is_some() { - return Err(serde::de::Error::duplicate_field(stringify!($field))); - } else if let Some(language_tag) = $language_tag_opt { - return Err( - serde::de::Error::custom( - format!( - concat!("unexpected language tag `{}` for key `", stringify!($field), "`"), - language_tag.as_ref() - ) - ) - ); - } - let seconds = $map.next_value::>()?; - $field = seconds.map(Duration::from_secs); - }; - (@case $map:ident $key:ident $language_tag_opt:ident - Option(DateTime(Seconds($field:ident)))) => { - if $field.is_some() { - return Err(serde::de::Error::duplicate_field(stringify!($field))); - } else if let Some(language_tag) = $language_tag_opt { - return Err( - serde::de::Error::custom( - format!( - concat!("unexpected language tag `{}` for key `", stringify!($field), "`"), - language_tag.as_ref() - ) - ) - ); - } - let seconds = $map.next_value::>()?; - $field = seconds - .map(|sec| timestamp_to_utc(&sec).map_err(|_| serde::de::Error::custom( - format!( - concat!( - "failed to parse `{}` as UTC datetime (in seconds) for key `", - stringify!($field), - "`" - ), - sec, - ) - ))).transpose()?; - }; - (@case $map:ident $key:ident $language_tag_opt:ident - Option(Boolean($field:ident))) => { - if $field.is_some() { - return Err(serde::de::Error::duplicate_field(stringify!($field))); - } else if let Some(language_tag) = $language_tag_opt { - return Err( - serde::de::Error::custom( - format!( - concat!("unexpected language tag `{}` for key `", stringify!($field), "`"), - language_tag.as_ref() - ) - ) - ); - } - let boolean = $map.next_value::>()?; - $field = boolean.map(|b| b.0); - }; - (@case $map:ident $key:ident $language_tag_opt:ident Option($field:ident)) => { - if $field.is_some() { - return Err(serde::de::Error::duplicate_field(stringify!($field))); - } else if let Some(language_tag) = $language_tag_opt { - return Err( - serde::de::Error::custom( - format!( - concat!("unexpected language tag `{}` for key `", stringify!($field), "`"), - language_tag.as_ref() - ) - ) - ); - } - $field = $map.next_value()?; - }; - (@case $map:ident $key:ident $language_tag_opt:ident LanguageTag($field:ident)) => { - let localized_claim = - if let Some(ref mut localized_claim) = $field { - localized_claim - } else { - let new = LocalizedClaim::new(); - $field = Some(new); - $field.as_mut().unwrap() - }; - if localized_claim.contains_key($language_tag_opt.as_ref()) { - return Err(serde::de::Error::custom(format!("duplicate field `{}`", $key))); - } - - localized_claim.insert($language_tag_opt, $map.next_value()?); - }; - (@case $map:ident $key:ident $language_tag_opt:ident $field:ident) => { - if $field.is_some() { - return Err(serde::de::Error::duplicate_field(stringify!($field))); - } else if let Some(language_tag) = $language_tag_opt { - return Err( - serde::de::Error::custom( - format!( - concat!("unexpected language tag `{}` for key `", stringify!($field), "`"), - language_tag.as_ref() - ) - ) - ); - } - $field = Some($map.next_value()?); - }; - (@struct_recurs [$($struct_type:tt)+] { - $($name:ident: $e:expr),* => [Option(Seconds($field_new:ident))] $([$($entry:tt)+])* - }) => { - deserialize_fields![ - @struct_recurs [$($struct_type)+] { - $($name: $e,)* $field_new: $field_new => $([$($entry)+])* - } - ] - }; - (@struct_recurs [$($struct_type:tt)+] { - $($name:ident: $e:expr),* => [Option(DateTime(Seconds($field_new:ident)))] $([$($entry:tt)+])* - }) => { - deserialize_fields![ - @struct_recurs [$($struct_type)+] { - $($name: $e,)* $field_new: $field_new => $([$($entry)+])* - } - ] - }; - (@struct_recurs [$($struct_type:tt)+] { - $($name:ident: $e:expr),* => [Option(Boolean($field_new:ident))] $([$($entry:tt)+])* - }) => { - deserialize_fields![ - @struct_recurs [$($struct_type)+] { - $($name: $e,)* $field_new: $field_new => $([$($entry)+])* - } - ] - }; - (@struct_recurs [$($struct_type:tt)+] { - $($name:ident: $e:expr),* => [Option($field_new:ident)] $([$($entry:tt)+])* - }) => { - deserialize_fields![ - @struct_recurs [$($struct_type)+] { - $($name: $e,)* $field_new: $field_new => $([$($entry)+])* - } - ] - }; - (@struct_recurs [$($struct_type:tt)+] { - $($name:ident: $e:expr),* => [LanguageTag($field_new:ident)] $([$($entry:tt)+])* - }) => { - deserialize_fields![ - @struct_recurs [$($struct_type)+] { - $($name: $e,)* $field_new: $field_new => $([$($entry)+])* - } - ] - }; - (@struct_recurs [$($struct_type:tt)+] { - $($name:ident: $e:expr),* => [$field_new:ident] $([$($entry:tt)+])* - }) => { - deserialize_fields![ - @struct_recurs [$($struct_type)+] { - $($name: $e,)* $field_new: - $field_new - .ok_or_else(|| serde::de::Error::missing_field(stringify!($field_new)))? => - $([$($entry)+])* - } - ] - }; - // Actually instantiate the struct. - (@struct_recurs [$($struct_type:tt)+] { - $($name:ident: $e:expr),+ => - }) => { - #[allow(clippy::redundant_field_names)] - $($struct_type)+ { - $($name: $e),+ - } - }; - // Main entry point - ( - $map:ident { - $([$($entry:tt)+])+ - } - ) => { - // let mut field_name = None; - $(deserialize_fields![@let_none $($entry)+];)+ - while let Some(key) = $map.next_key::()? { - let (field_name, language_tag_opt) = split_language_tag_key(&key); - match field_name { - $( - // "field_name" => { ... } - deserialize_fields![@field_str $($entry)+] => { - deserialize_fields![@case $map key language_tag_opt $($entry)+]; - }, - )+ - // Ignore unknown fields. - _ => { - $map.next_value::()?; - } - } - } - Ok(deserialize_fields![@struct_recurs [Self::Value] { => $([$($entry)+])* }]) - }; -} - macro_rules! serialize_fields { (@case $self:ident $map:ident Option(Seconds($field:ident))) => { if let Some(ref $field) = $self.$field { @@ -646,7 +435,7 @@ macro_rules! serialize_fields { }; (@case $self:ident $map:ident Option(DateTime(Seconds($field:ident)))) => { if let Some(ref $field) = $self.$field { - $map.serialize_entry(stringify!($field), &utc_to_seconds(&$field))?; + $map.serialize_entry(stringify!($field), &crate::helpers::Timestamp::from_utc(&$field))?; } }; (@case $self:ident $map:ident Option($field:ident)) => { diff --git a/src/registration/mod.rs b/src/registration/mod.rs index b50efb8..6b67f3d 100644 --- a/src/registration/mod.rs +++ b/src/registration/mod.rs @@ -1,4 +1,4 @@ -use crate::helpers::serde_utc_seconds_opt; +use crate::helpers::{DeserializeMapField, Timestamp}; use crate::http_utils::{auth_bearer, check_content_type, MIME_TYPE_JSON}; use crate::types::localized::split_language_tag_key; use crate::types::{ @@ -21,7 +21,7 @@ use http::status::StatusCode; use serde::de::{DeserializeOwned, Deserializer, MapAccess, Visitor}; use serde::ser::SerializeMap; use serde::{Deserialize, Serialize, Serializer}; -use serde_with::skip_serializing_none; +use serde_with::{serde_as, skip_serializing_none}; use thiserror::Error; use std::fmt::{Debug, Formatter, Result as FormatterResult}; @@ -239,16 +239,7 @@ where K: JsonWebKey, RT: ResponseType, S: SubjectIdentifierType, - >( - PhantomData, - PhantomData, - PhantomData, - PhantomData, - PhantomData, - PhantomData, - PhantomData, - PhantomData, - ); + >(PhantomData<(AT, CA, G, JE, JK, K, RT, S)>); impl<'de, AT, CA, G, JE, JK, K, RT, S> Visitor<'de> for MetadataVisitor where AT: ApplicationType, @@ -271,52 +262,172 @@ where where V: MapAccess<'de>, { - deserialize_fields! { - map { - [redirect_uris] - [Option(response_types)] - [Option(grant_types)] - [Option(application_type)] - [Option(contacts)] - [LanguageTag(client_name)] - [LanguageTag(logo_uri)] - [LanguageTag(client_uri)] - [LanguageTag(policy_uri)] - [LanguageTag(tos_uri)] - [Option(jwks_uri)] - [Option(jwks)] - [Option(sector_identifier_uri)] - [Option(subject_type)] - [Option(id_token_signed_response_alg)] - [Option(id_token_encrypted_response_alg)] - [Option(id_token_encrypted_response_enc)] - [Option(userinfo_signed_response_alg)] - [Option(userinfo_encrypted_response_alg)] - [Option(userinfo_encrypted_response_enc)] - [Option(request_object_signing_alg)] - [Option(request_object_encryption_alg)] - [Option(request_object_encryption_enc)] - [Option(token_endpoint_auth_method)] - [Option(token_endpoint_auth_signing_alg)] - [Option(Seconds(default_max_age))] - [Option(require_auth_time)] - [Option(default_acr_values)] - [Option(initiate_login_uri)] - [Option(request_uris)] - } + // NB: The non-localized fields are actually Option> here so that we can + // distinguish between omitted fields and fields explicitly set to `null`. The + // latter is necessary so that we can detect duplicate fields (e.g., if a key is + // present both with a null value and a non-null value, that's an error). + let mut redirect_uris = None; + let mut response_types = None; + let mut grant_types = None; + let mut application_type = None; + let mut contacts = None; + let mut client_name = None; + let mut logo_uri = None; + let mut client_uri = None; + let mut policy_uri = None; + let mut tos_uri = None; + let mut jwks_uri = None; + let mut jwks = None; + let mut sector_identifier_uri = None; + let mut subject_type = None; + let mut id_token_signed_response_alg = None; + let mut id_token_encrypted_response_alg = None; + let mut id_token_encrypted_response_enc = None; + let mut userinfo_signed_response_alg = None; + let mut userinfo_encrypted_response_alg = None; + let mut userinfo_encrypted_response_enc = None; + let mut request_object_signing_alg = None; + let mut request_object_encryption_alg = None; + let mut request_object_encryption_enc = None; + let mut token_endpoint_auth_method = None; + let mut token_endpoint_auth_signing_alg = None; + let mut default_max_age = None; + let mut require_auth_time = None; + let mut default_acr_values = None; + let mut initiate_login_uri = None; + let mut request_uris = None; + + macro_rules! field_case { + ($field:ident, $typ:ty, $language_tag:ident) => {{ + $field = Some(<$typ>::deserialize_map_field( + &mut map, + stringify!($field), + $language_tag, + $field, + )?); + }}; } + + while let Some(key) = map.next_key::()? { + let (field_name, language_tag) = split_language_tag_key(&key); + match field_name { + "redirect_uris" => field_case!(redirect_uris, Vec<_>, language_tag), + "response_types" => field_case!(response_types, Option<_>, language_tag), + "grant_types" => field_case!(grant_types, Option<_>, language_tag), + "application_type" => { + field_case!(application_type, Option<_>, language_tag) + } + "contacts" => field_case!(contacts, Option<_>, language_tag), + "client_name" => { + field_case!(client_name, LocalizedClaim>, language_tag) + } + "logo_uri" => { + field_case!(logo_uri, LocalizedClaim>, language_tag) + } + "client_uri" => { + field_case!(client_uri, LocalizedClaim>, language_tag) + } + "policy_uri" => { + field_case!(policy_uri, LocalizedClaim>, language_tag) + } + "tos_uri" => field_case!(tos_uri, LocalizedClaim>, language_tag), + "jwks_uri" => field_case!(jwks_uri, Option<_>, language_tag), + "jwks" => field_case!(jwks, Option<_>, language_tag), + "sector_identifier_uri" => { + field_case!(sector_identifier_uri, Option<_>, language_tag) + } + "subject_type" => field_case!(subject_type, Option<_>, language_tag), + "id_token_signed_response_alg" => { + field_case!(id_token_signed_response_alg, Option<_>, language_tag) + } + "id_token_encrypted_response_alg" => { + field_case!(id_token_encrypted_response_alg, Option<_>, language_tag) + } + "id_token_encrypted_response_enc" => { + field_case!(id_token_encrypted_response_enc, Option<_>, language_tag) + } + "userinfo_signed_response_alg" => { + field_case!(userinfo_signed_response_alg, Option<_>, language_tag) + } + "userinfo_encrypted_response_alg" => { + field_case!(userinfo_encrypted_response_alg, Option<_>, language_tag) + } + "userinfo_encrypted_response_enc" => { + field_case!(userinfo_encrypted_response_enc, Option<_>, language_tag) + } + "request_object_signing_alg" => { + field_case!(request_object_signing_alg, Option<_>, language_tag) + } + "request_object_encryption_alg" => { + field_case!(request_object_encryption_alg, Option<_>, language_tag) + } + "request_object_encryption_enc" => { + field_case!(request_object_encryption_enc, Option<_>, language_tag) + } + "token_endpoint_auth_method" => { + field_case!(token_endpoint_auth_method, Option<_>, language_tag) + } + "token_endpoint_auth_signing_alg" => { + field_case!(token_endpoint_auth_signing_alg, Option<_>, language_tag) + } + "default_max_age" => { + field_case!(default_max_age, Option, language_tag) + } + "require_auth_time" => { + field_case!(require_auth_time, Option<_>, language_tag) + } + "default_acr_values" => { + field_case!(default_acr_values, Option<_>, language_tag) + } + "initiate_login_uri" => { + field_case!(initiate_login_uri, Option<_>, language_tag) + } + "request_uris" => field_case!(request_uris, Option<_>, language_tag), + + // Ignore unknown fields. + _ => { + map.next_value::()?; + continue; + } + }; + } + + Ok(StandardClientMetadata { + redirect_uris: redirect_uris + .ok_or_else(|| serde::de::Error::missing_field("redirect_uris"))?, + response_types: response_types.flatten(), + grant_types: grant_types.flatten(), + application_type: application_type.flatten(), + contacts: contacts.flatten(), + client_name: client_name.and_then(LocalizedClaim::flatten_or_none), + logo_uri: logo_uri.and_then(LocalizedClaim::flatten_or_none), + client_uri: client_uri.and_then(LocalizedClaim::flatten_or_none), + policy_uri: policy_uri.and_then(LocalizedClaim::flatten_or_none), + tos_uri: tos_uri.and_then(LocalizedClaim::flatten_or_none), + jwks_uri: jwks_uri.flatten(), + jwks: jwks.flatten(), + sector_identifier_uri: sector_identifier_uri.flatten(), + subject_type: subject_type.flatten(), + id_token_signed_response_alg: id_token_signed_response_alg.flatten(), + id_token_encrypted_response_alg: id_token_encrypted_response_alg.flatten(), + id_token_encrypted_response_enc: id_token_encrypted_response_enc.flatten(), + userinfo_signed_response_alg: userinfo_signed_response_alg.flatten(), + userinfo_encrypted_response_alg: userinfo_encrypted_response_alg.flatten(), + userinfo_encrypted_response_enc: userinfo_encrypted_response_enc.flatten(), + request_object_signing_alg: request_object_signing_alg.flatten(), + request_object_encryption_alg: request_object_encryption_alg.flatten(), + request_object_encryption_enc: request_object_encryption_enc.flatten(), + token_endpoint_auth_method: token_endpoint_auth_method.flatten(), + token_endpoint_auth_signing_alg: token_endpoint_auth_signing_alg.flatten(), + default_max_age: default_max_age.flatten().map(Duration::from_secs), + require_auth_time: require_auth_time.flatten(), + default_acr_values: default_acr_values.flatten(), + initiate_login_uri: initiate_login_uri.flatten(), + request_uris: request_uris.flatten(), + }) } } - deserializer.deserialize_map(MetadataVisitor( - PhantomData, - PhantomData, - PhantomData, - PhantomData, - PhantomData, - PhantomData, - PhantomData, - PhantomData, - )) + deserializer.deserialize_map(MetadataVisitor(PhantomData)) } } impl Serialize for StandardClientMetadata @@ -622,6 +733,7 @@ pub struct EmptyAdditionalClientRegistrationResponse {} impl AdditionalClientRegistrationResponse for EmptyAdditionalClientRegistrationResponse {} /// Response to a dynamic client registration request. +#[serde_as] #[skip_serializing_none] #[derive(Debug, Deserialize, Serialize)] pub struct ClientRegistrationResponse @@ -643,9 +755,9 @@ where client_secret: Option, registration_access_token: Option, registration_client_uri: Option, - #[serde(with = "serde_utc_seconds_opt", default)] + #[serde_as(as = "Option")] client_id_issued_at: Option>, - #[serde(with = "serde_utc_seconds_opt", default)] + #[serde_as(as = "Option")] client_secret_expires_at: Option>, #[serde(bound = "AC: AdditionalClientMetadata", flatten)] client_metadata: ClientMetadata, diff --git a/src/types/localized.rs b/src/types/localized.rs index ae51daa..2c5946b 100644 --- a/src/types/localized.rs +++ b/src/types/localized.rs @@ -1,6 +1,8 @@ use serde::{Deserialize, Serialize}; +use std::borrow::Cow; use std::collections::HashMap; +use std::fmt::Display; new_type![ /// Language tag adhering to RFC 5646 (e.g., `fr` or `fr-CA`). @@ -12,6 +14,11 @@ impl AsRef for LanguageTag { self } } +impl Display for LanguageTag { + fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { + write!(f, "{}", self.as_ref()) + } +} pub(crate) fn split_language_tag_key(key: &str) -> (&str, Option) { let mut lang_tag_sep = key.splitn(2, '#'); @@ -27,6 +34,17 @@ pub(crate) fn split_language_tag_key(key: &str) -> (&str, Option) { (field_name, language_tag) } +pub(crate) fn join_language_tag_key<'a>( + field_name: &'a str, + language_tag: Option<&LanguageTag>, +) -> Cow<'a, str> { + if let Some(language_tag) = language_tag { + Cow::Owned(format!("{field_name}#{language_tag}")) + } else { + Cow::Borrowed(field_name) + } +} + /// A [locale-aware](https://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsLanguages) /// claim. /// @@ -90,6 +108,23 @@ impl LocalizedClaim { } } } +impl LocalizedClaim> { + pub(crate) fn flatten_or_none(self) -> Option> { + let flattened_tagged = self + .0 + .into_iter() + .filter_map(|(k, v)| v.map(|v| (k, v))) + .collect::>(); + let flattened_default = self.1.flatten(); + + if flattened_tagged.is_empty() && flattened_default.is_none() { + None + } else { + Some(LocalizedClaim(flattened_tagged, flattened_default)) + } + } +} + impl Default for LocalizedClaim { fn default() -> Self { Self(HashMap::new(), None) diff --git a/src/verification/tests.rs b/src/verification/tests.rs index 47fe12f..741353c 100644 --- a/src/verification/tests.rs +++ b/src/verification/tests.rs @@ -4,7 +4,7 @@ use crate::core::{ CoreJwsSigningAlgorithm, CoreRsaPrivateSigningKey, CoreUserInfoClaims, CoreUserInfoJsonWebToken, CoreUserInfoVerifier, }; -use crate::helpers::{timestamp_to_utc, Base64UrlEncodedBytes, Timestamp}; +use crate::helpers::{Base64UrlEncodedBytes, Timestamp}; use crate::jwt::tests::{TEST_RSA_PRIV_KEY, TEST_RSA_PUB_KEY}; use crate::jwt::{ JsonWebToken, JsonWebTokenHeader, JsonWebTokenJsonPayloadSerde, JsonWebTokenType, @@ -814,10 +814,9 @@ fn test_id_token_verified_claims() { CoreJsonWebKeySet::new(vec![rsa_key.clone()]), ) .set_time_fn(|| { - timestamp_to_utc(&Timestamp::Seconds( - mock_current_time.load(Ordering::Relaxed).into(), - )) - .unwrap() + Timestamp::Seconds(mock_current_time.load(Ordering::Relaxed).into()) + .to_utc() + .unwrap() }) .set_issue_time_verifier_fn(|_| { if mock_is_valid_issue_time.load(Ordering::Relaxed) { @@ -829,10 +828,9 @@ fn test_id_token_verified_claims() { let insecure_verifier = CoreIdTokenVerifier::new_insecure_without_verification() .set_time_fn(|| { - timestamp_to_utc(&Timestamp::Seconds( - mock_current_time.load(Ordering::Relaxed).into(), - )) - .unwrap() + Timestamp::Seconds(mock_current_time.load(Ordering::Relaxed).into()) + .to_utc() + .unwrap() }); // This JWTs below have an issue time of 1544928549 and an expiration time of 1544932149. @@ -991,7 +989,7 @@ fn test_id_token_verified_claims() { .set_auth_time_verifier_fn(|auth_time| { assert_eq!( auth_time.unwrap(), - timestamp_to_utc(&Timestamp::Seconds(1544928548.into())).unwrap(), + Timestamp::Seconds(1544928548.into()).to_utc().unwrap(), ); Err("Invalid auth_time claim".to_string()) }) @@ -1054,10 +1052,9 @@ fn test_id_token_verified_claims() { CoreJsonWebKeySet::new(vec![rsa_key.clone()]), ) .set_time_fn(|| { - timestamp_to_utc(&Timestamp::Seconds( - mock_current_time.load(Ordering::Relaxed).into(), - )) - .unwrap() + Timestamp::Seconds(mock_current_time.load(Ordering::Relaxed).into()) + .to_utc() + .unwrap() }); match private_client_verifier.verified_claims(&test_jwt_hs256, &valid_nonce) { Err(ClaimsVerificationError::SignatureVerification(_)) => {} @@ -1092,10 +1089,9 @@ fn test_id_token_verified_claims() { ) .allow_any_alg() .set_time_fn(|| { - timestamp_to_utc(&Timestamp::Seconds( - mock_current_time.load(Ordering::Relaxed).into(), - )) - .unwrap() + Timestamp::Seconds(mock_current_time.load(Ordering::Relaxed).into()) + .to_utc() + .unwrap() }); match private_client_verifier_with_other_secret .verified_claims(&test_jwt_hs256, &valid_nonce) @@ -1166,10 +1162,9 @@ fn test_new_id_token() { let mock_current_time = AtomicUsize::new(1544932148); let time_fn = || { - timestamp_to_utc(&Timestamp::Seconds( - mock_current_time.load(Ordering::Relaxed).into(), - )) - .unwrap() + Timestamp::Seconds(mock_current_time.load(Ordering::Relaxed).into()) + .to_utc() + .unwrap() }; let verifier = CoreIdTokenVerifier::new_public_client( client_id,