diff --git a/Cargo.toml b/Cargo.toml index e23ed37..01f3103 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,8 +29,7 @@ serde_cbor = { version = "0.11", default-features = false, optional = true } serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -# embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt.git", rev = "d75bda6" } -embedded-mqtt = { path = "../embedded-mqtt" } +embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt.git", rev = "d2b7c02" } futures = { version = "0.3.28", default-features = false } diff --git a/src/jobs/data_types.rs b/src/jobs/data_types.rs index 5910469..36449dd 100644 --- a/src/jobs/data_types.rs +++ b/src/jobs/data_types.rs @@ -22,7 +22,8 @@ pub enum JobStatus { Removed, } -#[derive(Debug, Clone, PartialEq, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum ErrorCode { /// The request was sent to a topic in the AWS IoT Jobs namespace that does /// not map to any API. @@ -89,7 +90,7 @@ pub struct GetPendingJobExecutionsResponse<'a> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Contains data about a job execution. @@ -211,7 +212,7 @@ pub struct StartNextPendingJobExecutionResponse<'a, J> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Topic (accepted): $aws/things/{thingName}/jobs/{jobId}/update/accepted \ @@ -232,7 +233,7 @@ pub struct UpdateJobExecutionResponse<'a, J> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Sent whenever a job execution is added to or removed from the list of @@ -289,7 +290,7 @@ pub struct Jobs { /// service operation. #[derive(Debug, PartialEq, Deserialize)] pub struct ErrorResponse<'a> { - code: ErrorCode, + pub code: ErrorCode, /// An error message string. message: &'a str, /// A client token used to correlate requests and responses. Enter an @@ -394,7 +395,7 @@ mod test { in_progress_jobs: Some(Vec::::new()), queued_jobs: None, timestamp: 1587381778, - client_token: "0:client_name", + client_token: Some("0:client_name"), } ); @@ -433,7 +434,7 @@ mod test { in_progress_jobs: Some(Vec::::new()), queued_jobs: Some(queued_jobs), timestamp: 1587381778, - client_token: "0:client_name", + client_token: Some("0:client_name"), } ); } diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index 291f8ff..4456428 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -1,18 +1,21 @@ use core::fmt::Write; +use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; -use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS}; +use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS, Subscribe, SubscribeTopic}; +use futures::StreamExt as _; use super::ControlInterface; -use crate::jobs::data_types::JobStatus; -use crate::jobs::{JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; +use crate::jobs::data_types::{ErrorResponse, JobStatus, UpdateJobExecutionResponse}; +use crate::jobs::{JobError, JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; use crate::ota::config::Config; use crate::ota::encoding::json::JobStatusReason; -use crate::ota::encoding::FileContext; +use crate::ota::encoding::{self, FileContext}; use crate::ota::error::OtaError; -impl<'a, M: RawMutex, const SUBS: usize> ControlInterface - for embedded_mqtt::MqttClient<'a, M, SUBS> +impl<'a, M: RawMutex, const SUBS: usize> ControlInterface for embedded_mqtt::MqttClient<'a, M, SUBS> +where + BitsImpl<{ SUBS }>: Bits, { /// Check for next available OTA job from the job service by publishing a /// "get next job" message to the job service. @@ -90,11 +93,39 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface } } + let mut sub = self + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + JobTopic::UpdateAccepted(file_ctx.job_name.as_str()) + .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( + self.client_id(), + )? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + JobTopic::UpdateRejected(file_ctx.job_name.as_str()) + .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( + self.client_id(), + )? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await?; + let topic = JobTopic::Update(file_ctx.job_name.as_str()) .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>(self.client_id())?; let payload = DeferredPayload::new( |buf| { Jobs::update(status) + .client_token(self.client_id()) .status_details(&file_ctx.status_details) .payload(buf) .map_err(|_| EncodingError::BufferSize) @@ -111,6 +142,44 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface ) .await?; - Ok(()) + loop { + let message = sub.next().await.ok_or(JobError::Encoding)?; + + // Check if topic is GetAccepted + match crate::jobs::Topic::from_str(message.topic_name()) { + Some(crate::jobs::Topic::UpdateAccepted(_)) => { + // Check client token + let (response, _) = serde_json_core::from_slice::< + UpdateJobExecutionResponse>, + >(message.payload()) + .map_err(|_| JobError::Encoding)?; + + if response.client_token != Some(self.client_id()) { + error!( + "Unexpected client token received: {}, expected: {}", + response.client_token.unwrap_or("None"), + self.client_id() + ); + continue; + } + + return Ok(()); + } + Some(crate::jobs::Topic::UpdateRejected(_)) => { + let (error_response, _) = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| JobError::Encoding)?; + + if error_response.client_token != Some(self.client_id()) { + continue; + } + + return Err(OtaError::UpdateRejected(error_response.code)); + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + } + } + } } } diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index edfdcd1..17bbf48 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -2,6 +2,7 @@ use core::fmt::{Display, Write}; use core::ops::DerefMut; use core::str::FromStr; +use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ DeferredPayload, EncodingError, MqttClient, Publish, Subscribe, SubscribeTopic, Subscription, @@ -123,13 +124,19 @@ impl<'a> OtaTopic<'a> { } } -impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1> { +impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1> +where + BitsImpl<{ SUBS }>: Bits, +{ async fn next_block(&mut self) -> Result>, OtaError> { Ok(self.next().await) } } -impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS> { +impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS> +where + BitsImpl<{ SUBS }>: Bits, +{ const PROTOCOL: Protocol = Protocol::Mqtt; type ActiveTransfer<'t> = Subscription<'a, 't, M, SUBS, 1> where Self: 't; diff --git a/src/ota/error.rs b/src/ota/error.rs index 8c5744b..119cc20 100644 --- a/src/ota/error.rs +++ b/src/ota/error.rs @@ -1,4 +1,4 @@ -use crate::jobs::JobError; +use crate::jobs::{data_types::ErrorCode, JobError}; use super::pal::OtaPalError; @@ -6,7 +6,6 @@ use super::pal::OtaPalError; #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum OtaError { NoActiveJob, - SignalEventFailed, Momentum, MomentumAbort, InvalidInterface, @@ -14,7 +13,9 @@ pub enum OtaError { BlockOutOfRange, ZeroFileSize, Overflow, + UnexpectedTopic, InvalidFile, + UpdateRejected(ErrorCode), Write( #[cfg_attr(feature = "defmt", defmt(Debug2Format))] embedded_storage_async::nor_flash::NorFlashErrorKind, diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 0c38f74..d86e2a7 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -4,9 +4,11 @@ pub mod topics; use core::future::Future; +use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ - DeferredPayload, EncodingError, Message, Publish, Subscribe, SubscribeTopic, Subscription, + BufferProvider, DeferredPayload, EncodingError, Message, Publish, Subscribe, SubscribeTopic, + Subscription, }; use futures::StreamExt; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -46,6 +48,7 @@ impl FleetProvisioner { credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -67,6 +70,7 @@ impl FleetProvisioner { credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -88,6 +92,7 @@ impl FleetProvisioner { credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -110,6 +115,7 @@ impl FleetProvisioner { credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -133,8 +139,11 @@ impl FleetProvisioner { payload_format: PayloadFormat, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { + use embedded_mqtt::SliceBufferProvider; + let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?; let mut message = create_subscription .next() @@ -143,10 +152,11 @@ impl FleetProvisioner { let ownership_token = match Topic::from_str(message.topic_name()) { Some(Topic::CreateKeysAndCertificateAccepted(format)) => { - let response = Self::deserialize::( - format, - &mut message, - )?; + let response = Self::deserialize::< + CreateKeysAndCertificateResponse, + SliceBufferProvider<'a>, + SUBS, + >(format, &mut message)?; credential_handler .store_credentials(Credentials { @@ -160,10 +170,11 @@ impl FleetProvisioner { } Some(Topic::CreateCertificateFromCsrAccepted(format)) => { - let response = Self::deserialize::( - format, - &mut message, - )?; + let response = Self::deserialize::< + CreateCertificateFromCsrResponse, + SliceBufferProvider<'a>, + SUBS, + >(format, &mut message)?; credential_handler .store_credentials(Credentials { @@ -253,8 +264,11 @@ impl FleetProvisioner { match Topic::from_str(message.topic_name()) { Some(Topic::RegisterThingAccepted(_, format)) => { - let response = - Self::deserialize::, SUBS>(format, &mut message)?; + let response = Self::deserialize::< + RegisterThingResponse<'_, C>, + SliceBufferProvider<'a>, + SUBS, + >(format, &mut message)?; Ok(response.device_configuration) } @@ -276,7 +290,10 @@ impl FleetProvisioner { mqtt: &'b embedded_mqtt::MqttClient<'a, M, SUBS>, csr: Option<&str>, payload_format: PayloadFormat, - ) -> Result, Error> { + ) -> Result, Error> + where + BitsImpl<{ SUBS }>: Bits, + { if let Some(csr) = csr { let subscription = mqtt .subscribe( @@ -378,10 +395,13 @@ impl FleetProvisioner { } } - fn deserialize<'a, R: Deserialize<'a>, const SUBS: usize>( + fn deserialize<'a, R: Deserialize<'a>, B: BufferProvider, const SUBS: usize>( payload_format: PayloadFormat, - message: &'a mut Message<'_, SUBS>, - ) -> Result { + message: &'a mut Message<'_, B, SUBS>, + ) -> Result + where + BitsImpl<{ SUBS }>: Bits, + { trace!( "Accepted Topic {:?}. Payload len: {:?}", payload_format, @@ -395,10 +415,13 @@ impl FleetProvisioner { }) } - fn handle_error( + fn handle_error( format: PayloadFormat, - mut message: Message<'_, SUBS>, - ) -> Result<(), Error> { + mut message: Message<'_, B, SUBS>, + ) -> Result<(), Error> + where + BitsImpl<{ SUBS }>: Bits, + { error!(">> {:?}", message.topic_name()); let response = match format { diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index 2b4c8a2..2825fd5 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -6,6 +6,7 @@ pub mod topics; use core::{marker::PhantomData, ops::DerefMut}; +use bitmaps::{Bits, BitsImpl}; pub use data_types::Patch; use embassy_sync::{ blocking_mutex::raw::{NoopRawMutex, RawMutex}, @@ -35,6 +36,7 @@ pub trait ShadowState: ShadowPatch + Default { struct ShadowHandler<'a, 'm, M: RawMutex, S: ShadowState, const SUBS: usize> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>, @@ -44,6 +46,7 @@ where impl<'a, 'm, M: RawMutex, S: ShadowState, const SUBS: usize> ShadowHandler<'a, 'm, M, S, SUBS> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { async fn handle_delta(&self) -> Result, Error> { @@ -378,6 +381,7 @@ where pub struct PersistedShadow<'a, 'm, S: ShadowState, M: RawMutex, D: ShadowDAO, const SUBS: usize> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { handler: ShadowHandler<'a, 'm, M, S, SUBS>, @@ -386,6 +390,7 @@ where impl<'a, 'm, S, M, D, const SUBS: usize> PersistedShadow<'a, 'm, S, M, D, SUBS> where + BitsImpl<{ SUBS }>: Bits, S: ShadowState + Default, M: RawMutex, D: ShadowDAO, @@ -495,6 +500,7 @@ where pub struct Shadow<'a, 'm, S: ShadowState, M: RawMutex, const SUBS: usize> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { state: S, @@ -503,6 +509,7 @@ where impl<'a, 'm, S, M, const SUBS: usize> Shadow<'a, 'm, S, M, SUBS> where + BitsImpl<{ SUBS }>: Bits, S: ShadowState, M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, @@ -583,6 +590,7 @@ where impl<'a, 'm, S, M, const SUBS: usize> core::fmt::Debug for Shadow<'a, 'm, S, M, SUBS> where + BitsImpl<{ SUBS }>: Bits, S: ShadowState + core::fmt::Debug, M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, @@ -600,6 +608,7 @@ where #[cfg(feature = "defmt")] impl<'a, 'm, S, M, const SUBS: usize> defmt::Format for Shadow<'a, 'm, S, M, SUBS> where + BitsImpl<{ SUBS }>: Bits, S: ShadowState + defmt::Format, M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:,