diff --git a/onvif/Cargo.toml b/onvif/Cargo.toml index 28b22a5..602e803 100644 --- a/onvif/Cargo.toml +++ b/onvif/Cargo.toml @@ -19,6 +19,7 @@ futures = "0.3" futures-core = "0.3" futures-util = "0.3" num-bigint = "0.4" +nonzero_ext = "0.3" reqwest = { version = "0.12", default-features = false } schema = { version = "0.1.0", path = "../schema", default-features = false, features = ["analytics", "devicemgmt", "event", "media", "ptz"] } sha1 = "0.6" diff --git a/onvif/src/soap/auth/digest.rs b/onvif/src/soap/auth/digest.rs index 425d772..91f39c5 100644 --- a/onvif/src/soap/auth/digest.rs +++ b/onvif/src/soap/auth/digest.rs @@ -1,6 +1,8 @@ use crate::soap::client::Credentials; +use nonzero_ext::nonzero; use reqwest::{RequestBuilder, Response}; use std::fmt::{Debug, Formatter}; +use std::num::NonZeroU8; use thiserror::Error; use url::Url; @@ -22,8 +24,10 @@ pub struct Digest { enum State { Default, - Got401(reqwest::Response), - Got401Twice, + Got401 { + response: Response, + count: NonZeroU8, + }, } impl Digest { @@ -37,29 +41,55 @@ impl Digest { } impl Digest { + /// Call this when the authentication was successful. + pub fn set_success(&mut self) { + if let State::Got401 { count, .. } = &mut self.state { + // We always store at least one request, so it's never zero. + *count = nonzero!(1_u8); + } + } + + /// Call this when received 401 Unauthorized. pub fn set_401(&mut self, response: Response) { - match self.state { - State::Default => self.state = State::Got401(response), - State::Got401(_) => self.state = State::Got401Twice, - State::Got401Twice => {} + self.state = match self.state { + State::Default => State::Got401 { + response, + count: nonzero!(1_u8), + }, + State::Got401 { count, .. } => State::Got401 { + response, + count: count.saturating_add(1), + }, } } pub fn is_failed(&self) -> bool { - matches!(self.state, State::Got401Twice) + match &self.state { + State::Default => false, + // Possible scenarios: + // - We've got 401 with a challenge for the first time, we calculate the answer, then + // we get 200 OK. So, a single 401 is never a failure. + // - After successful auth the count is 1 because we always store at least one request, + // and the caller decided to reuse the same challenge for multiple requests. But at + // some point, we'll get a 401 with a new challenge and `stale=true`. + // So, we'll get a second 401, and this is also not a failure because after + // calculating the answer to the challenge, we'll get a 200 OK, and will reset the + // counter in `set_success()`. + // - Three 401's in a row is certainly a failure. + State::Got401 { count, .. } => count.get() >= 3, + } } pub fn add_headers(&self, mut request: RequestBuilder) -> Result { match &self.state { State::Default => Ok(request), - State::Got401(response) => { + State::Got401 { response, .. } => { let creds = self.creds.as_ref().ok_or(Error::NoCredentials)?; request = request.header("Authorization", digest_auth(response, creds, &self.uri)?); Ok(request) } - State::Got401Twice => Err(Error::InvalidState), } } } @@ -94,10 +124,11 @@ impl Debug for Digest { impl Debug for State { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str(match self { - State::Default => "FirstRequest", - State::Got401(_) => "Got401", - State::Got401Twice => "Got401Twice", - }) + match self { + State::Default => write!(f, "FirstRequest")?, + State::Got401 { count, .. } => write!(f, "Got401({count})")?, + }; + + Ok(()) } } diff --git a/onvif/src/soap/client.rs b/onvif/src/soap/client.rs index 02f4fd0..5daa919 100644 --- a/onvif/src/soap/client.rs +++ b/onvif/src/soap/client.rs @@ -6,7 +6,9 @@ use crate::soap::{ }; use async_recursion::async_recursion; use async_trait::async_trait; +use futures_util::lock::Mutex; use schema::transport::{Error, Transport}; +use std::ops::DerefMut; use std::{ fmt::{Debug, Formatter}, sync::Arc, @@ -19,6 +21,7 @@ use url::Url; pub struct Client { client: reqwest::Client, config: Config, + digest_auth_state: Arc>, } #[derive(Clone)] @@ -84,9 +87,12 @@ impl ClientBuilder { .unwrap() }; + let digest = Digest::new(&self.config.uri, &self.config.credentials); + Client { client, config: self.config, + digest_auth_state: Arc::new(Mutex::new(digest)), } } @@ -144,8 +150,8 @@ impl Debug for Credentials { pub type ResponsePatcher = Arc Result + Send + Sync>; #[derive(Debug)] -enum RequestAuthType { - Digest(Digest), +enum RequestAuthType<'a> { + Digest(&'a mut Digest), UsernameToken, } @@ -172,8 +178,8 @@ impl Transport for Client { impl Client { async fn request_with_digest(&self, message: &str) -> Result { - let mut auth_type = - RequestAuthType::Digest(Digest::new(&self.config.uri, &self.config.credentials)); + let mut guard = self.digest_auth_state.lock().await; + let mut auth_type = RequestAuthType::Digest(guard.deref_mut()); self.request_recursive(message, &self.config.uri, &mut auth_type, 0) .await @@ -232,6 +238,10 @@ impl Client { debug!("Response status: {status}"); if status.is_success() { + if let RequestAuthType::Digest(digest) = auth_type { + digest.set_success(); + } + response .text() .await