From 32cfd963c3d206eeced508dc3fb58fa88e89d8c3 Mon Sep 17 00:00:00 2001 From: Caleb Schoepp Date: Mon, 15 Jul 2024 15:04:48 -0600 Subject: [PATCH] factors: Add more tests to factor-outbound-pg Signed-off-by: Caleb Schoepp --- crates/factor-outbound-pg/src/client.rs | 75 +++++++++++ crates/factor-outbound-pg/src/host.rs | 61 ++------- crates/factor-outbound-pg/src/lib.rs | 26 ++-- .../factor-outbound-pg/tests/factor_test.rs | 116 ++++++++++++++++-- 4 files changed, 207 insertions(+), 71 deletions(-) create mode 100644 crates/factor-outbound-pg/src/client.rs diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs new file mode 100644 index 0000000000..fb43d474ad --- /dev/null +++ b/crates/factor-outbound-pg/src/client.rs @@ -0,0 +1,75 @@ +use native_tls::TlsConnector; +use postgres_native_tls::MakeTlsConnector; +use spin_world::async_trait; +use tokio_postgres::{config::SslMode, types::ToSql, Error, Row, ToStatement}; +use tokio_postgres::{Client as PgClient, NoTls, Socket}; + +#[async_trait] +pub trait Client { + async fn build_client(address: &str) -> anyhow::Result + where + Self: Sized; + + async fn execute(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement + Sync + Send; + + async fn query(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement + Sync + Send; +} + +#[async_trait] +impl Client for PgClient { + async fn build_client(address: &str) -> anyhow::Result + where + Self: Sized, + { + let config = address.parse::()?; + + tracing::debug!("Build new connection: {}", address); + + if config.get_ssl_mode() == SslMode::Disable { + async move { + let (client, connection) = config.connect(NoTls).await?; + spawn(connection); + Ok(client) + } + .await + } else { + async move { + let builder = TlsConnector::builder(); + let connector = MakeTlsConnector::new(builder.build()?); + let (client, connection) = config.connect(connector).await?; + spawn(connection); + Ok(client) + } + .await + } + } + + async fn execute(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + { + self.execute(query, params).await + } + + async fn query(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement + Sync + Send, + { + self.query(query, params).await + } +} + +fn spawn(connection: tokio_postgres::Connection) +where + T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static, +{ + tokio::spawn(async move { + if let Err(e) = connection.await { + tracing::error!("Postgres connection error: {}", e); + } + }); +} diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index 63bc9ac91b..62c79bbf1a 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -1,6 +1,4 @@ use anyhow::{anyhow, Result}; -use native_tls::TlsConnector; -use postgres_native_tls::MakeTlsConnector; use spin_core::{async_trait, wasmtime::component::Resource}; use spin_world::v1::postgres as v1; use spin_world::v1::rdbms_types as v1_types; @@ -8,20 +6,20 @@ use spin_world::v2::postgres::{self as v2, Connection}; use spin_world::v2::rdbms_types; use spin_world::v2::rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet}; use tokio_postgres::{ - config::SslMode, types::{ToSql, Type}, - Client, NoTls, Row, Socket, + Row, }; use tracing::instrument; use tracing::Level; +use crate::client::Client; use crate::InstanceState; -impl InstanceState { +impl InstanceState { async fn open_connection(&mut self, address: &str) -> Result, v2::Error> { self.connections .push( - build_client(address) + C::build_client(address) .await .map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?, ) @@ -29,7 +27,7 @@ impl InstanceState { .map(Resource::new_own) } - async fn get_client(&mut self, connection: Resource) -> Result<&Client, v2::Error> { + async fn get_client(&mut self, connection: Resource) -> Result<&C, v2::Error> { self.connections .get(connection.rep()) .ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into())) @@ -66,10 +64,10 @@ impl InstanceState { } #[async_trait] -impl v2::Host for InstanceState {} +impl v2::Host for InstanceState {} #[async_trait] -impl v2::HostConnection for InstanceState { +impl v2::HostConnection for InstanceState { #[instrument(name = "spin_outbound_pg.open_connection", skip(self), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql"))] async fn open(&mut self, address: String) -> Result, v2::Error> { if !self @@ -150,7 +148,7 @@ impl v2::HostConnection for InstanceState { } } -impl rdbms_types::Host for InstanceState { +impl rdbms_types::Host for InstanceState { fn convert_error(&mut self, error: v2::Error) -> Result { Ok(error) } @@ -286,47 +284,6 @@ fn convert_entry(row: &Row, index: usize) -> Result anyhow::Result { - let config = address.parse::()?; - - tracing::debug!("Build new connection: {}", address); - - if config.get_ssl_mode() == SslMode::Disable { - connect(config).await - } else { - connect_tls(config).await - } -} - -async fn connect(config: tokio_postgres::Config) -> anyhow::Result { - let (client, connection) = config.connect(NoTls).await?; - - spawn(connection); - - Ok(client) -} - -async fn connect_tls(config: tokio_postgres::Config) -> anyhow::Result { - let builder = TlsConnector::builder(); - let connector = MakeTlsConnector::new(builder.build()?); - let (client, connection) = config.connect(connector).await?; - - spawn(connection); - - Ok(client) -} - -fn spawn(connection: tokio_postgres::Connection) -where - T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static, -{ - tokio::spawn(async move { - if let Err(e) = connection.await { - tracing::error!("Postgres connection error: {}", e); - } - }); -} - /// Although the Postgres crate converts Rust Option::None to Postgres NULL, /// it enforces the type of the Option as it does so. (For example, trying to /// pass an Option::::None to a VARCHAR column fails conversion.) As we @@ -388,7 +345,7 @@ macro_rules! delegate { } #[async_trait] -impl v1::Host for InstanceState { +impl v1::Host for InstanceState { async fn execute( &mut self, address: String, diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index 1436669321..a5ead7769c 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -1,18 +1,22 @@ +pub mod client; mod host; +use client::Client; use spin_factor_outbound_networking::{OutboundAllowedHosts, OutboundNetworkingFactor}; use spin_factors::{ anyhow, ConfigureAppContext, Factor, InstanceBuilders, PrepareContext, RuntimeFactors, SelfInstanceBuilder, }; -use tokio_postgres::Client; +use tokio_postgres::Client as PgClient; -pub struct OutboundPgFactor; +pub struct OutboundPgFactor { + _phantom: std::marker::PhantomData, +} -impl Factor for OutboundPgFactor { +impl Factor for OutboundPgFactor { type RuntimeConfig = (); type AppState = (); - type InstanceBuilder = InstanceState; + type InstanceBuilder = InstanceState; fn init( &mut self, @@ -45,9 +49,17 @@ impl Factor for OutboundPgFactor { } } -pub struct InstanceState { +impl Default for OutboundPgFactor { + fn default() -> Self { + Self { + _phantom: Default::default(), + } + } +} + +pub struct InstanceState { allowed_hosts: OutboundAllowedHosts, - connections: table::Table, + connections: table::Table, } -impl SelfInstanceBuilder for InstanceState {} +impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs index 4f2f788521..7baaa6b80b 100644 --- a/crates/factor-outbound-pg/tests/factor_test.rs +++ b/crates/factor-outbound-pg/tests/factor_test.rs @@ -1,39 +1,51 @@ -use anyhow::bail; +use anyhow::{bail, Result}; use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_pg::client::Client; use spin_factor_outbound_pg::OutboundPgFactor; use spin_factor_variables::{StaticVariables, VariablesFactor}; use spin_factor_wasi::{DummyFilesMounter, WasiFactor}; use spin_factors::{anyhow, RuntimeFactors}; use spin_factors_test::{toml, TestEnvironment}; +use spin_world::async_trait; use spin_world::v2::postgres::HostConnection; use spin_world::v2::rdbms_types::Error as PgError; +use tokio_postgres::types::ToSql; +use tokio_postgres::{Error, Row, ToStatement}; #[derive(RuntimeFactors)] struct TestFactors { wasi: WasiFactor, variables: VariablesFactor, networking: OutboundNetworkingFactor, - pg: OutboundPgFactor, + pg: OutboundPgFactor, +} + +fn factors() -> Result { + let mut f = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor, + pg: OutboundPgFactor::::default(), + }; + f.variables.add_provider_type(StaticVariables)?; + Ok(f) } fn test_env() -> TestEnvironment { TestEnvironment::default_manifest_extend(toml! { [component.test-component] source = "does-not-exist.wasm" + allowed_outbound_hosts = ["postgres://*:*"] }) } #[tokio::test] async fn disallowed_host_fails() -> anyhow::Result<()> { - let mut factors = TestFactors { - wasi: WasiFactor::new(DummyFilesMounter), - variables: VariablesFactor::default(), - networking: OutboundNetworkingFactor, - pg: OutboundPgFactor, - }; - factors.variables.add_provider_type(StaticVariables)?; - - let env = test_env(); + let factors = factors()?; + let env = TestEnvironment::default_manifest_extend(toml! { + [component.test-component] + source = "does-not-exist.wasm" + }); let mut state = env.build_instance_state(factors).await?; let res = state @@ -43,8 +55,88 @@ async fn disallowed_host_fails() -> anyhow::Result<()> { let Err(err) = res else { bail!("expected Err, got Ok"); }; - println!("err: {:?}", err); assert!(matches!(err, PgError::ConnectionFailed(_))); Ok(()) } + +#[tokio::test] +async fn allowed_host_succeeds() -> anyhow::Result<()> { + let factors = factors()?; + let env = test_env(); + let mut state = env.build_instance_state(factors).await?; + + let res = state + .pg + .open("postgres://localhost:5432/test".to_string()) + .await; + let Ok(_) = res else { + bail!("expected Ok, got Err"); + }; + + Ok(()) +} + +#[tokio::test] +async fn exercise_execute() -> anyhow::Result<()> { + let factors = factors()?; + let env = test_env(); + let mut state = env.build_instance_state(factors).await?; + + let connection = state + .pg + .open("postgres://localhost:5432/test".to_string()) + .await?; + + state + .pg + .execute(connection, "SELECT * FROM test".to_string(), vec![]) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn exercise_query() -> anyhow::Result<()> { + let factors = factors()?; + let env = test_env(); + let mut state = env.build_instance_state(factors).await?; + + let connection = state + .pg + .open("postgres://localhost:5432/test".to_string()) + .await?; + + state + .pg + .query(connection, "SELECT * FROM test".to_string(), vec![]) + .await?; + + Ok(()) +} + +pub struct MockClient {} + +#[async_trait] +impl Client for MockClient { + async fn build_client(_address: &str) -> anyhow::Result + where + Self: Sized, + { + Ok(MockClient {}) + } + + async fn execute(&self, _query: &T, _params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + { + Ok(0) + } + + async fn query(&self, _query: &T, _params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement + Sync + Send, + { + Ok(vec![]) + } +}