Skip to content

Commit

Permalink
use custom trait object instead of closure trait
Browse files Browse the repository at this point in the history
Signed-off-by: karthik2804 <[email protected]>
  • Loading branch information
karthik2804 authored and lann committed Aug 19, 2024
1 parent ec47aed commit 398cd3f
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 20 deletions.
2 changes: 1 addition & 1 deletion crates/factor-outbound-mqtt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ spin-factor-outbound-networking = { path = "../factor-outbound-networking" }
spin-factors = { path = "../factors" }
spin-core = { path = "../core" }
spin-world = { path = "../world" }
tracing = { workspace = true }
table = { path = "../table" }
tokio = { version = "1.0", features = ["sync"] }
tracing = { workspace = true }

[dev-dependencies]
spin-factor-variables = { path = "../factor-variables" }
Expand Down
28 changes: 15 additions & 13 deletions crates/factor-outbound-mqtt/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@ use spin_factor_outbound_networking::OutboundAllowedHosts;
use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos};
use tracing::{instrument, Level};

pub type CreateClient = Arc<
dyn Fn(String, String, String, Duration) -> Result<Box<dyn MqttClient>, Error> + Send + Sync,
>;
#[async_trait]
pub trait ClientCreator: Send + Sync {
fn create(
&self,
address: String,
username: String,
password: String,
keep_alive_interval: Duration,
) -> Result<Arc<dyn MqttClient>, Error>;
}

pub struct InstanceState {
pub allowed_hosts: OutboundAllowedHosts,
pub connections: table::Table<Box<dyn MqttClient>>,
pub create_client: CreateClient,
allowed_hosts: OutboundAllowedHosts,
connections: table::Table<Arc<dyn MqttClient>>,
create_client: Arc<dyn ClientCreator>,
}

impl InstanceState {
pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: CreateClient) -> Self {
pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: Arc<dyn ClientCreator>) -> Self {
Self {
allowed_hosts,
create_client,
Expand All @@ -44,12 +51,7 @@ impl InstanceState {
keep_alive_interval: Duration,
) -> Result<Resource<Connection>, Error> {
self.connections
.push((self.create_client)(
address,
username,
password,
keep_alive_interval,
)?)
.push((self.create_client).create(address, username, password, keep_alive_interval)?)
.map(Resource::new_own)
.map_err(|_| Error::TooManyConnections)
}
Expand Down
9 changes: 5 additions & 4 deletions crates/factor-outbound-mqtt/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
mod host;

use std::sync::Arc;
use std::time::Duration;

use host::other_error;
use host::CreateClient;
use host::InstanceState;
use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS};
use spin_core::async_trait;
Expand All @@ -15,14 +15,14 @@ use spin_factors::{
use spin_world::v2::mqtt::{self as v2, Error, Qos};
use tokio::sync::Mutex;

pub use host::MqttClient;
pub use host::{ClientCreator, MqttClient};

pub struct OutboundMqttFactor {
create_client: CreateClient,
create_client: Arc<dyn ClientCreator>,
}

impl OutboundMqttFactor {
pub fn new(create_client: CreateClient) -> Self {
pub fn new(create_client: Arc<dyn ClientCreator>) -> Self {
Self { create_client }
}
}
Expand Down Expand Up @@ -64,6 +64,7 @@ impl Factor for OutboundMqttFactor {

impl SelfInstanceBuilder for InstanceState {}

// This is a concrete implementation of the MQTT client using rumqttc.
pub struct NetworkedMqttClient {
inner: rumqttc::AsyncClient,
event_loop: Mutex<rumqttc::EventLoop>,
Expand Down
17 changes: 15 additions & 2 deletions crates/factor-outbound-mqtt/tests/factor_test.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::sync::Arc;
use std::time::Duration;

use anyhow::{bail, Result};
use spin_core::async_trait;
use spin_factor_outbound_mqtt::{MqttClient, OutboundMqttFactor};
use spin_factor_outbound_mqtt::{ClientCreator, MqttClient, OutboundMqttFactor};
use spin_factor_outbound_networking::OutboundNetworkingFactor;
use spin_factor_variables::VariablesFactor;
use spin_factors::{anyhow, RuntimeFactors};
Expand All @@ -23,6 +24,18 @@ impl MqttClient for MockMqttClient {
}
}

impl ClientCreator for MockMqttClient {
fn create(
&self,
_address: String,
_username: String,
_password: String,
_keep_alive_interval: Duration,
) -> Result<Arc<dyn MqttClient>, Error> {
Ok(Arc::new(MockMqttClient {}))
}
}

#[derive(RuntimeFactors)]
struct TestFactors {
variables: VariablesFactor,
Expand All @@ -34,7 +47,7 @@ fn factors() -> TestFactors {
TestFactors {
variables: VariablesFactor::default(),
networking: OutboundNetworkingFactor,
mqtt: OutboundMqttFactor::new(Arc::new(|_, _, _, _| Ok(Box::new(MockMqttClient {})))),
mqtt: OutboundMqttFactor::new(Arc::new(MockMqttClient {})),
}
}

Expand Down

0 comments on commit 398cd3f

Please sign in to comment.