From ec47aedb1a0c7e9eecfa7a06032c50551c93d588 Mon Sep 17 00:00:00 2001 From: karthik2804 Date: Mon, 5 Aug 2024 08:44:27 +0200 Subject: [PATCH] Add outbound MQTT factor Co-authored-by: rylev Signed-off-by: karthik2804 --- Cargo.lock | 17 +++ crates/factor-outbound-mqtt/Cargo.toml | 24 ++++ crates/factor-outbound-mqtt/src/host.rs | 131 ++++++++++++++++++ crates/factor-outbound-mqtt/src/lib.rs | 128 +++++++++++++++++ .../factor-outbound-mqtt/tests/factor_test.rs | 119 ++++++++++++++++ 5 files changed, 419 insertions(+) create mode 100644 crates/factor-outbound-mqtt/Cargo.toml create mode 100644 crates/factor-outbound-mqtt/src/host.rs create mode 100644 crates/factor-outbound-mqtt/src/lib.rs create mode 100644 crates/factor-outbound-mqtt/tests/factor_test.rs diff --git a/Cargo.lock b/Cargo.lock index e66bbf0d0..32d7e70fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7715,6 +7715,23 @@ dependencies = [ "wasmtime-wasi-http", ] +[[package]] +name = "spin-factor-outbound-mqtt" +version = "2.7.0-pre0" +dependencies = [ + "anyhow", + "rumqttc", + "spin-core", + "spin-factor-outbound-networking", + "spin-factor-variables", + "spin-factors", + "spin-factors-test", + "spin-world", + "table", + "tokio", + "tracing", +] + [[package]] name = "spin-factor-outbound-mysql" version = "2.7.0-pre0" diff --git a/crates/factor-outbound-mqtt/Cargo.toml b/crates/factor-outbound-mqtt/Cargo.toml new file mode 100644 index 000000000..76c44511e --- /dev/null +++ b/crates/factor-outbound-mqtt/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "spin-factor-outbound-mqtt" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } + +[dependencies] +anyhow = "1.0" +rumqttc = { version = "0.24", features = ["url"] } +spin-factor-outbound-networking = { path = "../factor-outbound-networking" } +spin-factors = { path = "../factors" } +spin-core = { path = "../core" } +spin-world = { path = "../world" } +tracing = { workspace = true } +table = { path = "../table" } +tokio = { version = "1.0", features = ["sync"] } + +[dev-dependencies] +spin-factor-variables = { path = "../factor-variables" } +spin-factors-test = { path = "../factors-test" } +tokio = { version = "1", features = ["macros", "rt"] } + +[lints] +workspace = true diff --git a/crates/factor-outbound-mqtt/src/host.rs b/crates/factor-outbound-mqtt/src/host.rs new file mode 100644 index 000000000..a6d0a1b0c --- /dev/null +++ b/crates/factor-outbound-mqtt/src/host.rs @@ -0,0 +1,131 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Result; +use spin_core::{async_trait, wasmtime::component::Resource}; +use spin_factor_outbound_networking::OutboundAllowedHosts; +use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos}; +use tracing::{instrument, Level}; + +pub type CreateClient = Arc< + dyn Fn(String, String, String, Duration) -> Result, Error> + Send + Sync, +>; + +pub struct InstanceState { + pub allowed_hosts: OutboundAllowedHosts, + pub connections: table::Table>, + pub create_client: CreateClient, +} + +impl InstanceState { + pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: CreateClient) -> Self { + Self { + allowed_hosts, + create_client, + connections: table::Table::new(1024), + } + } +} + +#[async_trait] +pub trait MqttClient: Send + Sync { + async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec) -> Result<(), Error>; +} + +impl InstanceState { + async fn is_address_allowed(&self, address: &str) -> Result { + self.allowed_hosts.check_url(address, "mqtt").await + } + + async fn establish_connection( + &mut self, + address: String, + username: String, + password: String, + keep_alive_interval: Duration, + ) -> Result, Error> { + self.connections + .push((self.create_client)( + address, + username, + password, + keep_alive_interval, + )?) + .map(Resource::new_own) + .map_err(|_| Error::TooManyConnections) + } + + async fn get_conn(&self, connection: Resource) -> Result<&dyn MqttClient, Error> { + self.connections + .get(connection.rep()) + .ok_or(Error::Other( + "could not find connection for resource".into(), + )) + .map(|c| c.as_ref()) + } +} + +impl v2::Host for InstanceState { + fn convert_error(&mut self, error: Error) -> Result { + Ok(error) + } +} + +#[async_trait] +impl v2::HostConnection for InstanceState { + #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))] + async fn open( + &mut self, + address: String, + username: String, + password: String, + keep_alive_interval: u64, + ) -> Result, Error> { + if !self + .is_address_allowed(&address) + .await + .map_err(|e| v2::Error::Other(e.to_string()))? + { + return Err(v2::Error::ConnectionFailed(format!( + "address {address} is not permitted" + ))); + } + self.establish_connection( + address, + username, + password, + Duration::from_secs(keep_alive_interval), + ) + .await + } + + /// Publish a message to the MQTT broker. + /// + /// OTEL trace propagation is not directly supported in MQTT V3. You will need to embed the + /// current trace context into the payload yourself. + /// https://w3c.github.io/trace-context-mqtt/#mqtt-v3-recommendation. + #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO), + fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish", + messaging.system = "mqtt"))] + async fn publish( + &mut self, + connection: Resource, + topic: String, + payload: Vec, + qos: Qos, + ) -> Result<(), Error> { + let conn = self.get_conn(connection).await.map_err(other_error)?; + + conn.publish_bytes(topic, qos, payload).await?; + + Ok(()) + } + + fn drop(&mut self, connection: Resource) -> anyhow::Result<()> { + self.connections.remove(connection.rep()); + Ok(()) + } +} + +pub fn other_error(e: impl std::fmt::Display) -> Error { + Error::Other(e.to_string()) +} diff --git a/crates/factor-outbound-mqtt/src/lib.rs b/crates/factor-outbound-mqtt/src/lib.rs new file mode 100644 index 000000000..db63318db --- /dev/null +++ b/crates/factor-outbound-mqtt/src/lib.rs @@ -0,0 +1,128 @@ +mod host; + +use std::time::Duration; + +use host::other_error; +use host::CreateClient; +use host::InstanceState; +use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS}; +use spin_core::async_trait; +use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factors::{ + ConfigureAppContext, Factor, InstanceBuilders, PrepareContext, RuntimeFactors, + SelfInstanceBuilder, +}; +use spin_world::v2::mqtt::{self as v2, Error, Qos}; +use tokio::sync::Mutex; + +pub use host::MqttClient; + +pub struct OutboundMqttFactor { + create_client: CreateClient, +} + +impl OutboundMqttFactor { + pub fn new(create_client: CreateClient) -> Self { + Self { create_client } + } +} + +impl Factor for OutboundMqttFactor { + type RuntimeConfig = (); + type AppState = (); + type InstanceBuilder = InstanceState; + + fn init( + &mut self, + mut ctx: spin_factors::InitContext, + ) -> anyhow::Result<()> { + ctx.link_bindings(spin_world::v2::mqtt::add_to_linker)?; + Ok(()) + } + + fn configure_app( + &self, + _ctx: ConfigureAppContext, + ) -> anyhow::Result { + Ok(()) + } + + fn prepare( + &self, + _ctx: PrepareContext, + builders: &mut InstanceBuilders, + ) -> anyhow::Result { + let allowed_hosts = builders + .get_mut::()? + .allowed_hosts(); + Ok(InstanceState::new( + allowed_hosts, + self.create_client.clone(), + )) + } +} + +impl SelfInstanceBuilder for InstanceState {} + +pub struct NetworkedMqttClient { + inner: rumqttc::AsyncClient, + event_loop: Mutex, +} + +const MQTT_CHANNEL_CAP: usize = 1000; + +impl NetworkedMqttClient { + pub fn create( + address: String, + username: String, + password: String, + keep_alive_interval: Duration, + ) -> Result { + let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| { + tracing::error!("MQTT URL parse error: {e:?}"); + Error::InvalidAddress + })?; + conn_opts.set_credentials(username, password); + conn_opts.set_keep_alive(keep_alive_interval); + let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP); + Ok(Self { + inner: client, + event_loop: Mutex::new(event_loop), + }) + } +} + +#[async_trait] +impl MqttClient for NetworkedMqttClient { + async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec) -> Result<(), Error> { + let qos = match qos { + Qos::AtMostOnce => rumqttc::QoS::AtMostOnce, + Qos::AtLeastOnce => rumqttc::QoS::AtLeastOnce, + Qos::ExactlyOnce => rumqttc::QoS::ExactlyOnce, + }; + // Message published to EventLoop (not MQTT Broker) + self.inner + .publish_bytes(topic, qos, false, payload.into()) + .await + .map_err(other_error)?; + + // Poll event loop until outgoing publish event is iterated over to send the message to MQTT broker or capture/throw error. + // We may revisit this later to manage long running connections, high throughput use cases and their issues in the connection pool. + let mut lock = self.event_loop.lock().await; + loop { + let event = lock + .poll() + .await + .map_err(|err| v2::Error::ConnectionFailed(err.to_string()))?; + + match (qos, event) { + (QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_))) + | (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_))) + | (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break, + + (_, _) => continue, + } + } + Ok(()) + } +} diff --git a/crates/factor-outbound-mqtt/tests/factor_test.rs b/crates/factor-outbound-mqtt/tests/factor_test.rs new file mode 100644 index 000000000..178d17a0e --- /dev/null +++ b/crates/factor-outbound-mqtt/tests/factor_test.rs @@ -0,0 +1,119 @@ +use std::sync::Arc; + +use anyhow::{bail, Result}; +use spin_core::async_trait; +use spin_factor_outbound_mqtt::{MqttClient, OutboundMqttFactor}; +use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_variables::VariablesFactor; +use spin_factors::{anyhow, RuntimeFactors}; +use spin_factors_test::{toml, TestEnvironment}; +use spin_world::v2::mqtt::{self as v2, Error, HostConnection, Qos}; + +pub struct MockMqttClient {} + +#[async_trait] +impl MqttClient for MockMqttClient { + async fn publish_bytes( + &self, + _topic: String, + _qos: Qos, + _payload: Vec, + ) -> Result<(), Error> { + Ok(()) + } +} + +#[derive(RuntimeFactors)] +struct TestFactors { + variables: VariablesFactor, + networking: OutboundNetworkingFactor, + mqtt: OutboundMqttFactor, +} + +fn factors() -> TestFactors { + TestFactors { + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor, + mqtt: OutboundMqttFactor::new(Arc::new(|_, _, _, _| Ok(Box::new(MockMqttClient {})))), + } +} + +fn test_env() -> TestEnvironment { + TestEnvironment::new(factors()).extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["mqtt://*:*"] + }) +} + +#[tokio::test] +async fn disallowed_host_fails() -> anyhow::Result<()> { + let env = TestEnvironment::new(factors()).extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + }); + let mut state = env.build_instance_state().await?; + + let res = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ) + .await; + let Err(err) = res else { + bail!("expected Err, got Ok"); + }; + assert!(matches!(err, v2::Error::ConnectionFailed(_))); + + Ok(()) +} + +#[tokio::test] +async fn allowed_host_succeeds() -> anyhow::Result<()> { + let mut state = test_env().build_instance_state().await?; + + let res = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ) + .await; + let Ok(_) = res else { + bail!("expected Ok, got Err"); + }; + + Ok(()) +} + +#[tokio::test] +async fn exercise_publish() -> anyhow::Result<()> { + let mut state = test_env().build_instance_state().await?; + + let res = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ) + .await?; + + state + .mqtt + .publish( + res, + "message".to_string(), + b"test message".to_vec(), + Qos::ExactlyOnce, + ) + .await?; + + Ok(()) +}