diff --git a/Cargo.lock b/Cargo.lock index 179335b8b..891e4802c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7605,6 +7605,26 @@ dependencies = [ "wasmtime-wasi", ] +[[package]] +name = "spin-factor-outbound-pg" +version = "2.7.0-pre0" +dependencies = [ + "anyhow", + "native-tls", + "postgres-native-tls", + "spin-core", + "spin-factor-outbound-networking", + "spin-factor-variables", + "spin-factor-wasi", + "spin-factors", + "spin-factors-test", + "spin-world", + "table", + "tokio", + "tokio-postgres", + "tracing", +] + [[package]] name = "spin-factor-variables" version = "2.7.0-pre0" diff --git a/crates/factor-outbound-pg/Cargo.toml b/crates/factor-outbound-pg/Cargo.toml new file mode 100644 index 000000000..ca18e93a1 --- /dev/null +++ b/crates/factor-outbound-pg/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "spin-factor-outbound-pg" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } + +[dependencies] +anyhow = "1.0" +native-tls = "0.2.11" +postgres-native-tls = "0.5.0" +spin-core = { path = "../core" } +spin-factor-outbound-networking = { path = "../factor-outbound-networking" } +spin-factors = { path = "../factors" } +spin-world = { path = "../world" } +table = { path = "../table" } +tokio = { version = "1", features = ["rt-multi-thread"] } +tokio-postgres = "0.7.7" +tracing = { workspace = true } + +[dev-dependencies] +spin-factor-variables = { path = "../factor-variables" } +spin-factor-wasi = { path = "../factor-wasi" } +spin-factors-test = { path = "../factors-test" } +tokio = { version = "1", features = ["macros", "rt"] } + +[lints] +workspace = true diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs new file mode 100644 index 000000000..63bc9ac91 --- /dev/null +++ b/crates/factor-outbound-pg/src/host.rs @@ -0,0 +1,422 @@ +use anyhow::{anyhow, Result}; +use native_tls::TlsConnector; +use postgres_native_tls::MakeTlsConnector; +use spin_core::{async_trait, wasmtime::component::Resource}; +use spin_world::v1::postgres as v1; +use spin_world::v1::rdbms_types as v1_types; +use spin_world::v2::postgres::{self as v2, Connection}; +use spin_world::v2::rdbms_types; +use spin_world::v2::rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet}; +use tokio_postgres::{ + config::SslMode, + types::{ToSql, Type}, + Client, NoTls, Row, Socket, +}; +use tracing::instrument; +use tracing::Level; + +use crate::InstanceState; + +impl InstanceState { + async fn open_connection(&mut self, address: &str) -> Result, v2::Error> { + self.connections + .push( + build_client(address) + .await + .map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?, + ) + .map_err(|_| v2::Error::ConnectionFailed("too many connections".into())) + .map(Resource::new_own) + } + + async fn get_client(&mut self, connection: Resource) -> Result<&Client, v2::Error> { + self.connections + .get(connection.rep()) + .ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into())) + } + + async fn is_address_allowed(&self, address: &str) -> Result { + let Ok(config) = address.parse::() else { + return Ok(false); + }; + for (i, host) in config.get_hosts().iter().enumerate() { + match host { + tokio_postgres::config::Host::Tcp(address) => { + let ports = config.get_ports(); + // The port we use is either: + // * The port at the same index as the host + // * The first port if there is only one port + let port = + ports + .get(i) + .or_else(|| if ports.len() == 1 { ports.get(1) } else { None }); + let port_str = port.map(|p| format!(":{}", p)).unwrap_or_default(); + let url = format!("{address}{port_str}"); + // TODO: Should I be unwrapping this? + if !self.allowed_hosts.check_url(&url, "postgres").await? { + return Ok(false); + } + } + #[cfg(unix)] + tokio_postgres::config::Host::Unix(_) => return Ok(false), + } + } + Ok(true) + } +} + +#[async_trait] +impl v2::Host for InstanceState {} + +#[async_trait] +impl v2::HostConnection for InstanceState { + #[instrument(name = "spin_outbound_pg.open_connection", skip(self), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql"))] + async fn open(&mut self, address: String) -> Result, v2::Error> { + if !self + .is_address_allowed(&address) + .await + .map_err(|e| v2::Error::Other(e.to_string()))? + { + return Err(v2::Error::ConnectionFailed(format!( + "address {address} is not permitted" + ))); + } + self.open_connection(&address).await + } + + #[instrument(name = "spin_outbound_pg.execute", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))] + async fn execute( + &mut self, + connection: Resource, + statement: String, + params: Vec, + ) -> Result { + let params: Vec<&(dyn ToSql + Sync)> = params + .iter() + .map(to_sql_parameter) + .collect::>>() + .map_err(|e| v2::Error::ValueConversionFailed(format!("{:?}", e)))?; + + let nrow = self + .get_client(connection) + .await? + .execute(&statement, params.as_slice()) + .await + .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?; + + Ok(nrow) + } + + #[instrument(name = "spin_outbound_pg.query", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))] + async fn query( + &mut self, + connection: Resource, + statement: String, + params: Vec, + ) -> Result { + let params: Vec<&(dyn ToSql + Sync)> = params + .iter() + .map(to_sql_parameter) + .collect::>>() + .map_err(|e| v2::Error::BadParameter(format!("{:?}", e)))?; + + let results = self + .get_client(connection) + .await? + .query(&statement, params.as_slice()) + .await + .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?; + + if results.is_empty() { + return Ok(RowSet { + columns: vec![], + rows: vec![], + }); + } + + let columns = infer_columns(&results[0]); + let rows = results + .iter() + .map(convert_row) + .collect::, _>>() + .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?; + + Ok(RowSet { columns, rows }) + } + + fn drop(&mut self, connection: Resource) -> anyhow::Result<()> { + self.connections.remove(connection.rep()); + Ok(()) + } +} + +impl rdbms_types::Host for InstanceState { + fn convert_error(&mut self, error: v2::Error) -> Result { + Ok(error) + } +} + +fn to_sql_parameter(value: &ParameterValue) -> anyhow::Result<&(dyn ToSql + Sync)> { + match value { + ParameterValue::Boolean(v) => Ok(v), + ParameterValue::Int32(v) => Ok(v), + ParameterValue::Int64(v) => Ok(v), + ParameterValue::Int8(v) => Ok(v), + ParameterValue::Int16(v) => Ok(v), + ParameterValue::Floating32(v) => Ok(v), + ParameterValue::Floating64(v) => Ok(v), + ParameterValue::Uint8(_) + | ParameterValue::Uint16(_) + | ParameterValue::Uint32(_) + | ParameterValue::Uint64(_) => Err(anyhow!("Postgres does not support unsigned integers")), + ParameterValue::Str(v) => Ok(v), + ParameterValue::Binary(v) => Ok(v), + ParameterValue::DbNull => Ok(&PgNull), + } +} + +fn infer_columns(row: &Row) -> Vec { + let mut result = Vec::with_capacity(row.len()); + for index in 0..row.len() { + result.push(infer_column(row, index)); + } + result +} + +fn infer_column(row: &Row, index: usize) -> Column { + let column = &row.columns()[index]; + let name = column.name().to_owned(); + let data_type = convert_data_type(column.type_()); + Column { name, data_type } +} + +fn convert_data_type(pg_type: &Type) -> DbDataType { + match *pg_type { + Type::BOOL => DbDataType::Boolean, + Type::BYTEA => DbDataType::Binary, + Type::FLOAT4 => DbDataType::Floating32, + Type::FLOAT8 => DbDataType::Floating64, + Type::INT2 => DbDataType::Int16, + Type::INT4 => DbDataType::Int32, + Type::INT8 => DbDataType::Int64, + Type::TEXT | Type::VARCHAR | Type::BPCHAR => DbDataType::Str, + _ => { + tracing::debug!("Couldn't convert Postgres type {} to WIT", pg_type.name(),); + DbDataType::Other + } + } +} + +fn convert_row(row: &Row) -> Result, tokio_postgres::Error> { + let mut result = Vec::with_capacity(row.len()); + for index in 0..row.len() { + result.push(convert_entry(row, index)?); + } + Ok(result) +} + +fn convert_entry(row: &Row, index: usize) -> Result { + let column = &row.columns()[index]; + let value = match column.type_() { + &Type::BOOL => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Boolean(v), + None => DbValue::DbNull, + } + } + &Type::BYTEA => { + let value: Option> = row.try_get(index)?; + match value { + Some(v) => DbValue::Binary(v), + None => DbValue::DbNull, + } + } + &Type::FLOAT4 => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Floating32(v), + None => DbValue::DbNull, + } + } + &Type::FLOAT8 => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Floating64(v), + None => DbValue::DbNull, + } + } + &Type::INT2 => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Int16(v), + None => DbValue::DbNull, + } + } + &Type::INT4 => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Int32(v), + None => DbValue::DbNull, + } + } + &Type::INT8 => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Int64(v), + None => DbValue::DbNull, + } + } + &Type::TEXT | &Type::VARCHAR | &Type::BPCHAR => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Str(v), + None => DbValue::DbNull, + } + } + t => { + tracing::debug!( + "Couldn't convert Postgres type {} in column {}", + t.name(), + column.name() + ); + DbValue::Unsupported + } + }; + Ok(value) +} + +async fn build_client(address: &str) -> anyhow::Result { + let config = address.parse::()?; + + tracing::debug!("Build new connection: {}", address); + + if config.get_ssl_mode() == SslMode::Disable { + connect(config).await + } else { + connect_tls(config).await + } +} + +async fn connect(config: tokio_postgres::Config) -> anyhow::Result { + let (client, connection) = config.connect(NoTls).await?; + + spawn(connection); + + Ok(client) +} + +async fn connect_tls(config: tokio_postgres::Config) -> anyhow::Result { + let builder = TlsConnector::builder(); + let connector = MakeTlsConnector::new(builder.build()?); + let (client, connection) = config.connect(connector).await?; + + spawn(connection); + + Ok(client) +} + +fn spawn(connection: tokio_postgres::Connection) +where + T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static, +{ + tokio::spawn(async move { + if let Err(e) = connection.await { + tracing::error!("Postgres connection error: {}", e); + } + }); +} + +/// Although the Postgres crate converts Rust Option::None to Postgres NULL, +/// it enforces the type of the Option as it does so. (For example, trying to +/// pass an Option::::None to a VARCHAR column fails conversion.) As we +/// do not know expected column types, we instead use a "neutral" custom type +/// which allows conversion to any type but always tells the Postgres crate to +/// treat it as a SQL NULL. +struct PgNull; + +impl ToSql for PgNull { + fn to_sql( + &self, + _ty: &Type, + _out: &mut tokio_postgres::types::private::BytesMut, + ) -> Result> + where + Self: Sized, + { + Ok(tokio_postgres::types::IsNull::Yes) + } + + fn accepts(_ty: &Type) -> bool + where + Self: Sized, + { + true + } + + fn to_sql_checked( + &self, + _ty: &Type, + _out: &mut tokio_postgres::types::private::BytesMut, + ) -> Result> { + Ok(tokio_postgres::types::IsNull::Yes) + } +} + +impl std::fmt::Debug for PgNull { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NULL").finish() + } +} + +/// Delegate a function call to the v2::HostConnection implementation +macro_rules! delegate { + ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{ + if !$self.is_address_allowed(&$address).await.map_err(|e| v2::Error::Other(e.to_string()))? { + return Err(v1::PgError::ConnectionFailed(format!( + "address {} is not permitted", $address + ))); + } + let connection = match $self.open_connection(&$address).await { + Ok(c) => c, + Err(e) => return Err(e.into()), + }; + ::$name($self, connection, $($arg),*) + .await + .map_err(|e| e.into()) + }}; +} + +#[async_trait] +impl v1::Host for InstanceState { + async fn execute( + &mut self, + address: String, + statement: String, + params: Vec, + ) -> Result { + delegate!(self.execute( + address, + statement, + params.into_iter().map(Into::into).collect() + )) + } + + async fn query( + &mut self, + address: String, + statement: String, + params: Vec, + ) -> Result { + delegate!(self.query( + address, + statement, + params.into_iter().map(Into::into).collect() + )) + .map(Into::into) + } + + fn convert_pg_error(&mut self, error: v1::PgError) -> Result { + Ok(error) + } +} diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs new file mode 100644 index 000000000..143666932 --- /dev/null +++ b/crates/factor-outbound-pg/src/lib.rs @@ -0,0 +1,53 @@ +mod host; + +use spin_factor_outbound_networking::{OutboundAllowedHosts, OutboundNetworkingFactor}; +use spin_factors::{ + anyhow, ConfigureAppContext, Factor, InstanceBuilders, PrepareContext, RuntimeFactors, + SelfInstanceBuilder, +}; +use tokio_postgres::Client; + +pub struct OutboundPgFactor; + +impl Factor for OutboundPgFactor { + type RuntimeConfig = (); + type AppState = (); + type InstanceBuilder = InstanceState; + + fn init( + &mut self, + mut ctx: spin_factors::InitContext, + ) -> anyhow::Result<()> { + ctx.link_bindings(spin_world::v1::postgres::add_to_linker)?; + ctx.link_bindings(spin_world::v2::postgres::add_to_linker)?; + Ok(()) + } + + fn configure_app( + &self, + _ctx: ConfigureAppContext, + ) -> anyhow::Result { + Ok(()) + } + + fn prepare( + &self, + _ctx: PrepareContext, + builders: &mut InstanceBuilders, + ) -> anyhow::Result { + let allowed_hosts = builders + .get_mut::()? + .allowed_hosts(); + Ok(InstanceState { + allowed_hosts, + connections: Default::default(), + }) + } +} + +pub struct InstanceState { + allowed_hosts: OutboundAllowedHosts, + connections: table::Table, +} + +impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs new file mode 100644 index 000000000..4f2f78852 --- /dev/null +++ b/crates/factor-outbound-pg/tests/factor_test.rs @@ -0,0 +1,50 @@ +use anyhow::bail; +use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_pg::OutboundPgFactor; +use spin_factor_variables::{StaticVariables, VariablesFactor}; +use spin_factor_wasi::{DummyFilesMounter, WasiFactor}; +use spin_factors::{anyhow, RuntimeFactors}; +use spin_factors_test::{toml, TestEnvironment}; +use spin_world::v2::postgres::HostConnection; +use spin_world::v2::rdbms_types::Error as PgError; + +#[derive(RuntimeFactors)] +struct TestFactors { + wasi: WasiFactor, + variables: VariablesFactor, + networking: OutboundNetworkingFactor, + pg: OutboundPgFactor, +} + +fn test_env() -> TestEnvironment { + TestEnvironment::default_manifest_extend(toml! { + [component.test-component] + source = "does-not-exist.wasm" + }) +} + +#[tokio::test] +async fn disallowed_host_fails() -> anyhow::Result<()> { + let mut factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor, + pg: OutboundPgFactor, + }; + factors.variables.add_provider_type(StaticVariables)?; + + let env = test_env(); + let mut state = env.build_instance_state(factors).await?; + + let res = state + .pg + .open("postgres://postgres.test:5432/test".to_string()) + .await; + let Err(err) = res else { + bail!("expected Err, got Ok"); + }; + println!("err: {:?}", err); + assert!(matches!(err, PgError::ConnectionFailed(_))); + + Ok(()) +}