Skip to content

Commit

Permalink
Merge branch 'feature/async-shadows' of github.com:BlackbirdHQ/rustot…
Browse files Browse the repository at this point in the history
… into enhancement/async_shadows_mutex
  • Loading branch information
KennethKnudsen97 committed Sep 25, 2024
2 parents b03624e + 8b4a21a commit eda4a75
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 38 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ serde_cbor = { version = "0.11", default-features = false, optional = true }
serde-json-core = { version = "0.5" }
shadow-derive = { path = "shadow_derive", version = "0.2.1" }
embedded-storage-async = "0.4"
# embedded-mqtt = { git = "ssh://[email protected]/FactbirdHQ/embedded-mqtt.git", rev = "d75bda6" }
embedded-mqtt = { path = "../embedded-mqtt" }
embedded-mqtt = { git = "ssh://[email protected]/FactbirdHQ/embedded-mqtt.git", rev = "d2b7c02" }

futures = { version = "0.3.28", default-features = false }

Expand Down
15 changes: 8 additions & 7 deletions src/jobs/data_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ pub enum JobStatus {
Removed,
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ErrorCode {
/// The request was sent to a topic in the AWS IoT Jobs namespace that does
/// not map to any API.
Expand Down Expand Up @@ -89,7 +90,7 @@ pub struct GetPendingJobExecutionsResponse<'a> {
/// A client token used to correlate requests and responses. Enter an
/// arbitrary value here and it is reflected in the response.
#[serde(rename = "clientToken")]
pub client_token: &'a str,
pub client_token: Option<&'a str>,
}

/// Contains data about a job execution.
Expand Down Expand Up @@ -211,7 +212,7 @@ pub struct StartNextPendingJobExecutionResponse<'a, J> {
/// A client token used to correlate requests and responses. Enter an
/// arbitrary value here and it is reflected in the response.
#[serde(rename = "clientToken")]
pub client_token: &'a str,
pub client_token: Option<&'a str>,
}

/// Topic (accepted): $aws/things/{thingName}/jobs/{jobId}/update/accepted \
Expand All @@ -232,7 +233,7 @@ pub struct UpdateJobExecutionResponse<'a, J> {
/// A client token used to correlate requests and responses. Enter an
/// arbitrary value here and it is reflected in the response.
#[serde(rename = "clientToken")]
pub client_token: &'a str,
pub client_token: Option<&'a str>,
}

/// Sent whenever a job execution is added to or removed from the list of
Expand Down Expand Up @@ -289,7 +290,7 @@ pub struct Jobs {
/// service operation.
#[derive(Debug, PartialEq, Deserialize)]
pub struct ErrorResponse<'a> {
code: ErrorCode,
pub code: ErrorCode,
/// An error message string.
message: &'a str,
/// A client token used to correlate requests and responses. Enter an
Expand Down Expand Up @@ -394,7 +395,7 @@ mod test {
in_progress_jobs: Some(Vec::<JobExecutionSummary, MAX_RUNNING_JOBS>::new()),
queued_jobs: None,
timestamp: 1587381778,
client_token: "0:client_name",
client_token: Some("0:client_name"),
}
);

Expand Down Expand Up @@ -433,7 +434,7 @@ mod test {
in_progress_jobs: Some(Vec::<JobExecutionSummary, MAX_RUNNING_JOBS>::new()),
queued_jobs: Some(queued_jobs),
timestamp: 1587381778,
client_token: "0:client_name",
client_token: Some("0:client_name"),
}
);
}
Expand Down
83 changes: 76 additions & 7 deletions src/ota/control_interface/mqtt.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
use core::fmt::Write;

use bitmaps::{Bits, BitsImpl};
use embassy_sync::blocking_mutex::raw::RawMutex;
use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS};
use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS, Subscribe, SubscribeTopic};
use futures::StreamExt as _;

use super::ControlInterface;
use crate::jobs::data_types::JobStatus;
use crate::jobs::{JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN};
use crate::jobs::data_types::{ErrorResponse, JobStatus, UpdateJobExecutionResponse};
use crate::jobs::{JobError, JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN};
use crate::ota::config::Config;
use crate::ota::encoding::json::JobStatusReason;
use crate::ota::encoding::FileContext;
use crate::ota::encoding::{self, FileContext};
use crate::ota::error::OtaError;

impl<'a, M: RawMutex, const SUBS: usize> ControlInterface
for embedded_mqtt::MqttClient<'a, M, SUBS>
impl<'a, M: RawMutex, const SUBS: usize> ControlInterface for embedded_mqtt::MqttClient<'a, M, SUBS>
where
BitsImpl<{ SUBS }>: Bits,
{
/// Check for next available OTA job from the job service by publishing a
/// "get next job" message to the job service.
Expand Down Expand Up @@ -90,11 +93,39 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface
}
}

let mut sub = self
.subscribe::<2>(
Subscribe::builder()
.topics(&[
SubscribeTopic::builder()
.topic_path(
JobTopic::UpdateAccepted(file_ctx.job_name.as_str())
.format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>(
self.client_id(),
)?
.as_str(),
)
.build(),
SubscribeTopic::builder()
.topic_path(
JobTopic::UpdateRejected(file_ctx.job_name.as_str())
.format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>(
self.client_id(),
)?
.as_str(),
)
.build(),
])
.build(),
)
.await?;

let topic = JobTopic::Update(file_ctx.job_name.as_str())
.format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>(self.client_id())?;
let payload = DeferredPayload::new(
|buf| {
Jobs::update(status)
.client_token(self.client_id())
.status_details(&file_ctx.status_details)
.payload(buf)
.map_err(|_| EncodingError::BufferSize)
Expand All @@ -111,6 +142,44 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface
)
.await?;

Ok(())
loop {
let message = sub.next().await.ok_or(JobError::Encoding)?;

// Check if topic is GetAccepted
match crate::jobs::Topic::from_str(message.topic_name()) {
Some(crate::jobs::Topic::UpdateAccepted(_)) => {
// Check client token
let (response, _) = serde_json_core::from_slice::<
UpdateJobExecutionResponse<encoding::json::OtaJob<'_>>,
>(message.payload())
.map_err(|_| JobError::Encoding)?;

if response.client_token != Some(self.client_id()) {
error!(
"Unexpected client token received: {}, expected: {}",
response.client_token.unwrap_or("None"),
self.client_id()
);
continue;
}

return Ok(());
}
Some(crate::jobs::Topic::UpdateRejected(_)) => {
let (error_response, _) =
serde_json_core::from_slice::<ErrorResponse>(message.payload())
.map_err(|_| JobError::Encoding)?;

if error_response.client_token != Some(self.client_id()) {
continue;
}

return Err(OtaError::UpdateRejected(error_response.code));
}
_ => {
error!("Expected Topic name GetRejected or GetAccepted but got something else");
}
}
}
}
}
11 changes: 9 additions & 2 deletions src/ota/data_interface/mqtt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use core::fmt::{Display, Write};
use core::ops::DerefMut;
use core::str::FromStr;

use bitmaps::{Bits, BitsImpl};
use embassy_sync::blocking_mutex::raw::RawMutex;
use embedded_mqtt::{
DeferredPayload, EncodingError, MqttClient, Publish, Subscribe, SubscribeTopic, Subscription,
Expand Down Expand Up @@ -123,13 +124,19 @@ impl<'a> OtaTopic<'a> {
}
}

impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1> {
impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1>
where
BitsImpl<{ SUBS }>: Bits,
{
async fn next_block(&mut self) -> Result<Option<impl DerefMut<Target = [u8]>>, OtaError> {
Ok(self.next().await)
}
}

impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS> {
impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS>
where
BitsImpl<{ SUBS }>: Bits,
{
const PROTOCOL: Protocol = Protocol::Mqtt;

type ActiveTransfer<'t> = Subscription<'a, 't, M, SUBS, 1> where Self: 't;
Expand Down
5 changes: 3 additions & 2 deletions src/ota/error.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use crate::jobs::JobError;
use crate::jobs::{data_types::ErrorCode, JobError};

use super::pal::OtaPalError;

#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum OtaError {
NoActiveJob,
SignalEventFailed,
Momentum,
MomentumAbort,
InvalidInterface,
ResetFailed,
BlockOutOfRange,
ZeroFileSize,
Overflow,
UnexpectedTopic,
InvalidFile,
UpdateRejected(ErrorCode),
Write(
#[cfg_attr(feature = "defmt", defmt(Debug2Format))]
embedded_storage_async::nor_flash::NorFlashErrorKind,
Expand Down
59 changes: 41 additions & 18 deletions src/provisioning/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ pub mod topics;

use core::future::Future;

use bitmaps::{Bits, BitsImpl};
use embassy_sync::blocking_mutex::raw::RawMutex;
use embedded_mqtt::{
DeferredPayload, EncodingError, Message, Publish, Subscribe, SubscribeTopic, Subscription,
BufferProvider, DeferredPayload, EncodingError, Message, Publish, Subscribe, SubscribeTopic,
Subscription,
};
use futures::StreamExt;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
Expand Down Expand Up @@ -46,6 +48,7 @@ impl FleetProvisioner {
credential_handler: &mut impl CredentialHandler,
) -> Result<Option<C>, Error>
where
BitsImpl<{ SUBS }>: Bits,
C: DeserializeOwned,
{
Self::provision_inner(
Expand All @@ -67,6 +70,7 @@ impl FleetProvisioner {
credential_handler: &mut impl CredentialHandler,
) -> Result<Option<C>, Error>
where
BitsImpl<{ SUBS }>: Bits,
C: DeserializeOwned,
{
Self::provision_inner(
Expand All @@ -88,6 +92,7 @@ impl FleetProvisioner {
credential_handler: &mut impl CredentialHandler,
) -> Result<Option<C>, Error>
where
BitsImpl<{ SUBS }>: Bits,
C: DeserializeOwned,
{
Self::provision_inner(
Expand All @@ -110,6 +115,7 @@ impl FleetProvisioner {
credential_handler: &mut impl CredentialHandler,
) -> Result<Option<C>, Error>
where
BitsImpl<{ SUBS }>: Bits,
C: DeserializeOwned,
{
Self::provision_inner(
Expand All @@ -133,8 +139,11 @@ impl FleetProvisioner {
payload_format: PayloadFormat,
) -> Result<Option<C>, Error>
where
BitsImpl<{ SUBS }>: Bits,
C: DeserializeOwned,
{
use embedded_mqtt::SliceBufferProvider;

let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?;
let mut message = create_subscription
.next()
Expand All @@ -143,10 +152,11 @@ impl FleetProvisioner {

let ownership_token = match Topic::from_str(message.topic_name()) {
Some(Topic::CreateKeysAndCertificateAccepted(format)) => {
let response = Self::deserialize::<CreateKeysAndCertificateResponse, SUBS>(
format,
&mut message,
)?;
let response = Self::deserialize::<
CreateKeysAndCertificateResponse,
SliceBufferProvider<'a>,
SUBS,
>(format, &mut message)?;

credential_handler
.store_credentials(Credentials {
Expand All @@ -160,10 +170,11 @@ impl FleetProvisioner {
}

Some(Topic::CreateCertificateFromCsrAccepted(format)) => {
let response = Self::deserialize::<CreateCertificateFromCsrResponse, SUBS>(
format,
&mut message,
)?;
let response = Self::deserialize::<
CreateCertificateFromCsrResponse,
SliceBufferProvider<'a>,
SUBS,
>(format, &mut message)?;

credential_handler
.store_credentials(Credentials {
Expand Down Expand Up @@ -253,8 +264,11 @@ impl FleetProvisioner {

match Topic::from_str(message.topic_name()) {
Some(Topic::RegisterThingAccepted(_, format)) => {
let response =
Self::deserialize::<RegisterThingResponse<'_, C>, SUBS>(format, &mut message)?;
let response = Self::deserialize::<
RegisterThingResponse<'_, C>,
SliceBufferProvider<'a>,
SUBS,
>(format, &mut message)?;

Ok(response.device_configuration)
}
Expand All @@ -276,7 +290,10 @@ impl FleetProvisioner {
mqtt: &'b embedded_mqtt::MqttClient<'a, M, SUBS>,
csr: Option<&str>,
payload_format: PayloadFormat,
) -> Result<Subscription<'a, 'b, M, SUBS, 2>, Error> {
) -> Result<Subscription<'a, 'b, M, SUBS, 2>, Error>
where
BitsImpl<{ SUBS }>: Bits,
{
if let Some(csr) = csr {
let subscription = mqtt
.subscribe(
Expand Down Expand Up @@ -378,10 +395,13 @@ impl FleetProvisioner {
}
}

fn deserialize<'a, R: Deserialize<'a>, const SUBS: usize>(
fn deserialize<'a, R: Deserialize<'a>, B: BufferProvider, const SUBS: usize>(
payload_format: PayloadFormat,
message: &'a mut Message<'_, SUBS>,
) -> Result<R, Error> {
message: &'a mut Message<'_, B, SUBS>,
) -> Result<R, Error>
where
BitsImpl<{ SUBS }>: Bits,
{
trace!(
"Accepted Topic {:?}. Payload len: {:?}",
payload_format,
Expand All @@ -395,10 +415,13 @@ impl FleetProvisioner {
})
}

fn handle_error<const SUBS: usize>(
fn handle_error<B: BufferProvider, const SUBS: usize>(
format: PayloadFormat,
mut message: Message<'_, SUBS>,
) -> Result<(), Error> {
mut message: Message<'_, B, SUBS>,
) -> Result<(), Error>
where
BitsImpl<{ SUBS }>: Bits,
{
error!(">> {:?}", message.topic_name());

let response = match format {
Expand Down
Loading

0 comments on commit eda4a75

Please sign in to comment.