Skip to content

Commit

Permalink
factors: Add more tests to factor-outbound-pg
Browse files Browse the repository at this point in the history
Signed-off-by: Caleb Schoepp <[email protected]>
  • Loading branch information
calebschoepp committed Jul 15, 2024
1 parent 112c4e4 commit 32cfd96
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 71 deletions.
75 changes: 75 additions & 0 deletions crates/factor-outbound-pg/src/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_world::async_trait;
use tokio_postgres::{config::SslMode, types::ToSql, Error, Row, ToStatement};
use tokio_postgres::{Client as PgClient, NoTls, Socket};

#[async_trait]
pub trait Client {
async fn build_client(address: &str) -> anyhow::Result<Self>
where
Self: Sized;

async fn execute<T>(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error>
where
T: ?Sized + ToStatement + Sync + Send;

async fn query<T>(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement + Sync + Send;
}

#[async_trait]
impl Client for PgClient {
async fn build_client(address: &str) -> anyhow::Result<Self>
where
Self: Sized,
{
let config = address.parse::<tokio_postgres::Config>()?;

tracing::debug!("Build new connection: {}", address);

if config.get_ssl_mode() == SslMode::Disable {
async move {
let (client, connection) = config.connect(NoTls).await?;
spawn(connection);
Ok(client)
}
.await
} else {
async move {
let builder = TlsConnector::builder();
let connector = MakeTlsConnector::new(builder.build()?);
let (client, connection) = config.connect(connector).await?;
spawn(connection);
Ok(client)
}
.await
}
}

async fn execute<T>(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error>
where
T: ?Sized + ToStatement + Sync + Send,
{
self.execute(query, params).await
}

async fn query<T>(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement + Sync + Send,
{
self.query(query, params).await
}
}

fn spawn<T>(connection: tokio_postgres::Connection<Socket, T>)
where
T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static,
{
tokio::spawn(async move {
if let Err(e) = connection.await {
tracing::error!("Postgres connection error: {}", e);
}
});
}
61 changes: 9 additions & 52 deletions crates/factor-outbound-pg/src/host.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,33 @@
use anyhow::{anyhow, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_core::{async_trait, wasmtime::component::Resource};
use spin_world::v1::postgres as v1;
use spin_world::v1::rdbms_types as v1_types;
use spin_world::v2::postgres::{self as v2, Connection};
use spin_world::v2::rdbms_types;
use spin_world::v2::rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet};
use tokio_postgres::{
config::SslMode,
types::{ToSql, Type},
Client, NoTls, Row, Socket,
Row,
};
use tracing::instrument;
use tracing::Level;

use crate::client::Client;
use crate::InstanceState;

impl InstanceState {
impl<C: Client> InstanceState<C> {
async fn open_connection(&mut self, address: &str) -> Result<Resource<Connection>, v2::Error> {
self.connections
.push(
build_client(address)
C::build_client(address)
.await
.map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
)
.map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))
.map(Resource::new_own)
}

async fn get_client(&mut self, connection: Resource<Connection>) -> Result<&Client, v2::Error> {
async fn get_client(&mut self, connection: Resource<Connection>) -> Result<&C, v2::Error> {
self.connections
.get(connection.rep())
.ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into()))
Expand Down Expand Up @@ -66,10 +64,10 @@ impl InstanceState {
}

#[async_trait]
impl v2::Host for InstanceState {}
impl<C: Send + Sync + Client> v2::Host for InstanceState<C> {}

#[async_trait]
impl v2::HostConnection for InstanceState {
impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
#[instrument(name = "spin_outbound_pg.open_connection", skip(self), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql"))]
async fn open(&mut self, address: String) -> Result<Resource<Connection>, v2::Error> {
if !self
Expand Down Expand Up @@ -150,7 +148,7 @@ impl v2::HostConnection for InstanceState {
}
}

impl rdbms_types::Host for InstanceState {
impl<C: Send> rdbms_types::Host for InstanceState<C> {
fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
Ok(error)
}
Expand Down Expand Up @@ -286,47 +284,6 @@ fn convert_entry(row: &Row, index: usize) -> Result<DbValue, tokio_postgres::Err
Ok(value)
}

async fn build_client(address: &str) -> anyhow::Result<Client> {
let config = address.parse::<tokio_postgres::Config>()?;

tracing::debug!("Build new connection: {}", address);

if config.get_ssl_mode() == SslMode::Disable {
connect(config).await
} else {
connect_tls(config).await
}
}

async fn connect(config: tokio_postgres::Config) -> anyhow::Result<Client> {
let (client, connection) = config.connect(NoTls).await?;

spawn(connection);

Ok(client)
}

async fn connect_tls(config: tokio_postgres::Config) -> anyhow::Result<Client> {
let builder = TlsConnector::builder();
let connector = MakeTlsConnector::new(builder.build()?);
let (client, connection) = config.connect(connector).await?;

spawn(connection);

Ok(client)
}

fn spawn<T>(connection: tokio_postgres::Connection<Socket, T>)
where
T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static,
{
tokio::spawn(async move {
if let Err(e) = connection.await {
tracing::error!("Postgres connection error: {}", e);
}
});
}

/// Although the Postgres crate converts Rust Option::None to Postgres NULL,
/// it enforces the type of the Option as it does so. (For example, trying to
/// pass an Option::<i32>::None to a VARCHAR column fails conversion.) As we
Expand Down Expand Up @@ -388,7 +345,7 @@ macro_rules! delegate {
}

#[async_trait]
impl v1::Host for InstanceState {
impl<C: Send + Sync + Client> v1::Host for InstanceState<C> {
async fn execute(
&mut self,
address: String,
Expand Down
26 changes: 19 additions & 7 deletions crates/factor-outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
pub mod client;
mod host;

use client::Client;
use spin_factor_outbound_networking::{OutboundAllowedHosts, OutboundNetworkingFactor};
use spin_factors::{
anyhow, ConfigureAppContext, Factor, InstanceBuilders, PrepareContext, RuntimeFactors,
SelfInstanceBuilder,
};
use tokio_postgres::Client;
use tokio_postgres::Client as PgClient;

pub struct OutboundPgFactor;
pub struct OutboundPgFactor<C = PgClient> {
_phantom: std::marker::PhantomData<C>,
}

impl Factor for OutboundPgFactor {
impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
type RuntimeConfig = ();
type AppState = ();
type InstanceBuilder = InstanceState;
type InstanceBuilder = InstanceState<C>;

fn init<T: RuntimeFactors>(
&mut self,
Expand Down Expand Up @@ -45,9 +49,17 @@ impl Factor for OutboundPgFactor {
}
}

pub struct InstanceState {
impl<C> Default for OutboundPgFactor<C> {
fn default() -> Self {
Self {
_phantom: Default::default(),
}
}
}

pub struct InstanceState<C> {
allowed_hosts: OutboundAllowedHosts,
connections: table::Table<Client>,
connections: table::Table<C>,
}

impl SelfInstanceBuilder for InstanceState {}
impl<C: Send + 'static> SelfInstanceBuilder for InstanceState<C> {}
116 changes: 104 additions & 12 deletions crates/factor-outbound-pg/tests/factor_test.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,51 @@
use anyhow::bail;
use anyhow::{bail, Result};
use spin_factor_outbound_networking::OutboundNetworkingFactor;
use spin_factor_outbound_pg::client::Client;
use spin_factor_outbound_pg::OutboundPgFactor;
use spin_factor_variables::{StaticVariables, VariablesFactor};
use spin_factor_wasi::{DummyFilesMounter, WasiFactor};
use spin_factors::{anyhow, RuntimeFactors};
use spin_factors_test::{toml, TestEnvironment};
use spin_world::async_trait;
use spin_world::v2::postgres::HostConnection;
use spin_world::v2::rdbms_types::Error as PgError;
use tokio_postgres::types::ToSql;
use tokio_postgres::{Error, Row, ToStatement};

#[derive(RuntimeFactors)]
struct TestFactors {
wasi: WasiFactor,
variables: VariablesFactor,
networking: OutboundNetworkingFactor,
pg: OutboundPgFactor,
pg: OutboundPgFactor<MockClient>,
}

fn factors() -> Result<TestFactors> {
let mut f = TestFactors {
wasi: WasiFactor::new(DummyFilesMounter),
variables: VariablesFactor::default(),
networking: OutboundNetworkingFactor,
pg: OutboundPgFactor::<MockClient>::default(),
};
f.variables.add_provider_type(StaticVariables)?;
Ok(f)
}

fn test_env() -> TestEnvironment {
TestEnvironment::default_manifest_extend(toml! {
[component.test-component]
source = "does-not-exist.wasm"
allowed_outbound_hosts = ["postgres://*:*"]
})
}

#[tokio::test]
async fn disallowed_host_fails() -> anyhow::Result<()> {
let mut factors = TestFactors {
wasi: WasiFactor::new(DummyFilesMounter),
variables: VariablesFactor::default(),
networking: OutboundNetworkingFactor,
pg: OutboundPgFactor,
};
factors.variables.add_provider_type(StaticVariables)?;

let env = test_env();
let factors = factors()?;
let env = TestEnvironment::default_manifest_extend(toml! {
[component.test-component]
source = "does-not-exist.wasm"
});
let mut state = env.build_instance_state(factors).await?;

let res = state
Expand All @@ -43,8 +55,88 @@ async fn disallowed_host_fails() -> anyhow::Result<()> {
let Err(err) = res else {
bail!("expected Err, got Ok");
};
println!("err: {:?}", err);
assert!(matches!(err, PgError::ConnectionFailed(_)));

Ok(())
}

#[tokio::test]
async fn allowed_host_succeeds() -> anyhow::Result<()> {
let factors = factors()?;
let env = test_env();
let mut state = env.build_instance_state(factors).await?;

let res = state
.pg
.open("postgres://localhost:5432/test".to_string())
.await;
let Ok(_) = res else {
bail!("expected Ok, got Err");
};

Ok(())
}

#[tokio::test]
async fn exercise_execute() -> anyhow::Result<()> {
let factors = factors()?;
let env = test_env();
let mut state = env.build_instance_state(factors).await?;

let connection = state
.pg
.open("postgres://localhost:5432/test".to_string())
.await?;

state
.pg
.execute(connection, "SELECT * FROM test".to_string(), vec![])
.await?;

Ok(())
}

#[tokio::test]
async fn exercise_query() -> anyhow::Result<()> {
let factors = factors()?;
let env = test_env();
let mut state = env.build_instance_state(factors).await?;

let connection = state
.pg
.open("postgres://localhost:5432/test".to_string())
.await?;

state
.pg
.query(connection, "SELECT * FROM test".to_string(), vec![])
.await?;

Ok(())
}

pub struct MockClient {}

#[async_trait]
impl Client for MockClient {
async fn build_client(_address: &str) -> anyhow::Result<Self>
where
Self: Sized,
{
Ok(MockClient {})
}

async fn execute<T>(&self, _query: &T, _params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error>
where
T: ?Sized + ToStatement + Sync + Send,
{
Ok(0)
}

async fn query<T>(&self, _query: &T, _params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement + Sync + Send,
{
Ok(vec![])
}
}

0 comments on commit 32cfd96

Please sign in to comment.