Skip to content

Commit

Permalink
Initial commit (untested) of Postgres date-time types
Browse files Browse the repository at this point in the history
Signed-off-by: itowlson <[email protected]>
  • Loading branch information
itowlson committed Oct 17, 2024
1 parent 1154682 commit e14e1e4
Show file tree
Hide file tree
Showing 11 changed files with 646 additions and 53 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/factor-outbound-pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ edition = { workspace = true }

[dependencies]
anyhow = { workspace = true }
chrono = "0.4"
native-tls = "0.2"
postgres-native-tls = "0.5"
spin-core = { path = "../core" }
Expand All @@ -14,7 +15,7 @@ spin-factors = { path = "../factors" }
spin-resource-table = { path = "../table" }
spin-world = { path = "../world" }
tokio = { workspace = true, features = ["rt-multi-thread"] }
tokio-postgres = "0.7"
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4"] }
tracing = { workspace = true }

[dev-dependencies]
Expand Down
70 changes: 53 additions & 17 deletions crates/factor-outbound-pg/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use anyhow::{anyhow, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_world::async_trait;
use spin_world::v2::postgres::{self as v2};
use spin_world::v2::rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet};
use spin_world::spin::postgres::rdbms_types::{
self as v2, Column, DbDataType, DbValue, ParameterValue, RowSet,
};
use tokio_postgres::types::Type;
use tokio_postgres::{config::SslMode, types::ToSql, Row};
use tokio_postgres::{Client as TokioClient, NoTls, Socket};
Expand Down Expand Up @@ -55,13 +56,18 @@ impl Client for TokioClient {
statement: String,
params: Vec<ParameterValue>,
) -> Result<u64, v2::Error> {
let params: Vec<&(dyn ToSql + Sync)> = params
let params = params
.iter()
.map(to_sql_parameter)
.collect::<Result<Vec<_>>>()
.map_err(|e| v2::Error::ValueConversionFailed(format!("{:?}", e)))?;

self.execute(&statement, params.as_slice())
let params_refs: Vec<&(dyn ToSql + Sync)> = params
.iter()
.map(|b| b.as_ref() as &(dyn ToSql + Sync))
.collect();

self.execute(&statement, params_refs.as_slice())
.await
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))
}
Expand All @@ -71,14 +77,19 @@ impl Client for TokioClient {
statement: String,
params: Vec<ParameterValue>,
) -> Result<RowSet, v2::Error> {
let params: Vec<&(dyn ToSql + Sync)> = params
let params = params
.iter()
.map(to_sql_parameter)
.collect::<Result<Vec<_>>>()
.map_err(|e| v2::Error::BadParameter(format!("{:?}", e)))?;

let params_refs: Vec<&(dyn ToSql + Sync)> = params
.iter()
.map(|b| b.as_ref() as &(dyn ToSql + Sync))
.collect();

let results = self
.query(&statement, params.as_slice())
.query(&statement, params_refs.as_slice())
.await
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;

Expand Down Expand Up @@ -111,22 +122,47 @@ where
});
}

fn to_sql_parameter(value: &ParameterValue) -> Result<&(dyn ToSql + Sync)> {
fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Sync>> {
match value {
ParameterValue::Boolean(v) => Ok(v),
ParameterValue::Int32(v) => Ok(v),
ParameterValue::Int64(v) => Ok(v),
ParameterValue::Int8(v) => Ok(v),
ParameterValue::Int16(v) => Ok(v),
ParameterValue::Floating32(v) => Ok(v),
ParameterValue::Floating64(v) => Ok(v),
ParameterValue::Boolean(v) => Ok(Box::new(*v)),
ParameterValue::Int32(v) => Ok(Box::new(*v)),
ParameterValue::Int64(v) => Ok(Box::new(*v)),
ParameterValue::Int8(v) => Ok(Box::new(*v)),
ParameterValue::Int16(v) => Ok(Box::new(*v)),
ParameterValue::Floating32(v) => Ok(Box::new(*v)),
ParameterValue::Floating64(v) => Ok(Box::new(*v)),
ParameterValue::Uint8(_)
| ParameterValue::Uint16(_)
| ParameterValue::Uint32(_)
| ParameterValue::Uint64(_) => Err(anyhow!("Postgres does not support unsigned integers")),
ParameterValue::Str(v) => Ok(v),
ParameterValue::Binary(v) => Ok(v),
ParameterValue::DbNull => Ok(&PgNull),
ParameterValue::Str(v) => Ok(Box::new(v.clone())),
ParameterValue::Binary(v) => Ok(Box::new(v.clone())),
ParameterValue::Date((y, mon, d)) => {
let naive_date = chrono::NaiveDate::from_ymd_opt(*y, (*mon).into(), (*d).into())
.ok_or_else(|| anyhow!("invalid date y={y}, m={mon}, d={d}"))?;
Ok(Box::new(naive_date))
}
ParameterValue::Time((h, min, s, ns)) => {
let naive_time =
chrono::NaiveTime::from_hms_nano_opt((*h).into(), (*min).into(), (*s).into(), *ns)
.ok_or_else(|| anyhow!("invalid time {h}:{min}:{s}:{ns}"))?;
Ok(Box::new(naive_time))
}
ParameterValue::Datetime((y, mon, d, h, min, s, ns)) => {
let naive_date = chrono::NaiveDate::from_ymd_opt(*y, (*mon).into(), (*d).into())
.ok_or_else(|| anyhow!("invalid date y={y}, m={mon}, d={d}"))?;
let naive_time =
chrono::NaiveTime::from_hms_nano_opt((*h).into(), (*min).into(), (*s).into(), *ns)
.ok_or_else(|| anyhow!("invalid time {h}:{min}:{s}:{ns}"))?;
let dt = chrono::NaiveDateTime::new(naive_date, naive_time);
Ok(Box::new(dt))
}
ParameterValue::Timestamp(v) => {
let ts = chrono::DateTime::<chrono::Utc>::from_timestamp(*v, 0)
.ok_or_else(|| anyhow!("invalid epoch timestamp {v}"))?;
Ok(Box::new(ts))
}
ParameterValue::DbNull => Ok(Box::new(PgNull)),
}
}

Expand Down
125 changes: 101 additions & 24 deletions crates/factor-outbound-pg/src/host.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use anyhow::Result;
use spin_core::{async_trait, wasmtime::component::Resource};
use spin_world::spin::postgres::{self as v3};
use spin_world::v1::postgres as v1;
use spin_world::v1::rdbms_types as v1_types;
use spin_world::v2::postgres::{self as v2, Connection};
use spin_world::v2::rdbms_types;
use spin_world::v2::rdbms_types::{ParameterValue, RowSet};
use spin_world::v2::postgres::{self as v2};
use spin_world::v2::rdbms_types as v2types;
use tracing::field::Empty;
use tracing::instrument;
use tracing::Level;
Expand All @@ -13,21 +13,27 @@ use crate::client::Client;
use crate::InstanceState;

impl<C: Client> InstanceState<C> {
async fn open_connection(&mut self, address: &str) -> Result<Resource<Connection>, v2::Error> {
async fn open_connection<Conn: 'static>(
&mut self,
address: &str,
) -> Result<Resource<Conn>, v3::rdbms_types::Error> {
self.connections
.push(
C::build_client(address)
.await
.map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
.map_err(|e| v3::rdbms_types::Error::ConnectionFailed(format!("{e:?}")))?,
)
.map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))
.map_err(|_| v3::rdbms_types::Error::ConnectionFailed("too many connections".into()))
.map(Resource::new_own)
}

async fn get_client(&mut self, connection: Resource<Connection>) -> Result<&C, v2::Error> {
async fn get_client<Conn: 'static>(
&mut self,
connection: Resource<Conn>,
) -> Result<&C, v3::rdbms_types::Error> {
self.connections
.get(connection.rep())
.ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into()))
.ok_or_else(|| v3::rdbms_types::Error::ConnectionFailed("no connection found".into()))
}

async fn is_address_allowed(&self, address: &str) -> Result<bool> {
Expand Down Expand Up @@ -60,20 +66,29 @@ impl<C: Client> InstanceState<C> {
}

#[async_trait]
impl<C: Send + Sync + Client> v2::Host for InstanceState<C> {}
impl<C: Send + Sync + Client> v3::postgres::Host for InstanceState<C> {}

fn v2_params_to_v3(params: Vec<v2types::ParameterValue>) -> Vec<v3::rdbms_types::ParameterValue> {
params.into_iter().map(|p| p.into()).collect()
}

#[async_trait]
impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnection
for InstanceState<C>
{
#[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
async fn open(&mut self, address: String) -> Result<Resource<Connection>, v2::Error> {
async fn open(
&mut self,
address: String,
) -> Result<Resource<v3::postgres::Connection>, v3::rdbms_types::Error> {
spin_factor_outbound_networking::record_address_fields(&address);

if !self
.is_address_allowed(&address)
.await
.map_err(|e| v2::Error::Other(e.to_string()))?
.map_err(|e| v3::rdbms_types::Error::Other(e.to_string()))?
{
return Err(v2::Error::ConnectionFailed(format!(
return Err(v3::rdbms_types::Error::ConnectionFailed(format!(
"address {address} is not permitted"
)));
}
Expand All @@ -83,10 +98,10 @@ impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
#[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
async fn execute(
&mut self,
connection: Resource<Connection>,
connection: Resource<v3::postgres::Connection>,
statement: String,
params: Vec<ParameterValue>,
) -> Result<u64, v2::Error> {
params: Vec<v3::rdbms_types::ParameterValue>,
) -> Result<u64, v3::rdbms_types::Error> {
Ok(self
.get_client(connection)
.await?
Expand All @@ -97,33 +112,39 @@ impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
#[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
async fn query(
&mut self,
connection: Resource<Connection>,
connection: Resource<v3::postgres::Connection>,
statement: String,
params: Vec<ParameterValue>,
) -> Result<RowSet, v2::Error> {
params: Vec<v3::rdbms_types::ParameterValue>,
) -> Result<v3::rdbms_types::RowSet, v3::rdbms_types::Error> {
Ok(self
.get_client(connection)
.await?
.query(statement, params)
.await?)
}

async fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> {
async fn drop(&mut self, connection: Resource<v3::postgres::Connection>) -> anyhow::Result<()> {
self.connections.remove(connection.rep());
Ok(())
}
}

impl<C: Send> rdbms_types::Host for InstanceState<C> {
impl<C: Send> v2types::Host for InstanceState<C> {
fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
Ok(error)
}
}

/// Delegate a function call to the v2::HostConnection implementation
impl<C: Send> v3::rdbms_types::Host for InstanceState<C> {
fn convert_error(&mut self, error: v3::rdbms_types::Error) -> Result<v3::rdbms_types::Error> {
Ok(error)
}
}

/// Delegate a function call to the v3::HostConnection implementation
macro_rules! delegate {
($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
if !$self.is_address_allowed(&$address).await.map_err(|e| v2::Error::Other(e.to_string()))? {
if !$self.is_address_allowed(&$address).await.map_err(|e| v3::rdbms_types::Error::Other(e.to_string()))? {
return Err(v1::PgError::ConnectionFailed(format!(
"address {} is not permitted", $address
)));
Expand All @@ -132,12 +153,68 @@ macro_rules! delegate {
Ok(c) => c,
Err(e) => return Err(e.into()),
};
<Self as v2::HostConnection>::$name($self, connection, $($arg),*)
<Self as v3::postgres::HostConnection>::$name($self, connection, $($arg),*)
.await
.map_err(|e| e.into())
}};
}

#[async_trait]
impl<C: Send + Sync + Client> v2::Host for InstanceState<C> {}

#[async_trait]
impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
#[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
spin_factor_outbound_networking::record_address_fields(&address);

if !self
.is_address_allowed(&address)
.await
.map_err(|e| v2::Error::Other(e.to_string()))?
{
return Err(v2::Error::ConnectionFailed(format!(
"address {address} is not permitted"
)));
}
Ok(self.open_connection(&address).await?)
}

#[instrument(name = "spin_outbound_pg.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
async fn execute(
&mut self,
connection: Resource<v2::Connection>,
statement: String,
params: Vec<v2types::ParameterValue>,
) -> Result<u64, v2::Error> {
Ok(self
.get_client(connection)
.await?
.execute(statement, v2_params_to_v3(params))
.await?)
}

#[instrument(name = "spin_outbound_pg.query", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))]
async fn query(
&mut self,
connection: Resource<v2::Connection>,
statement: String,
params: Vec<v2types::ParameterValue>,
) -> Result<v2types::RowSet, v2::Error> {
Ok(self
.get_client(connection)
.await?
.query(statement, v2_params_to_v3(params))
.await?
.into())
}

async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
self.connections.remove(connection.rep());
Ok(())
}
}

#[async_trait]
impl<C: Send + Sync + Client> v1::Host for InstanceState<C> {
async fn execute(
Expand Down
1 change: 1 addition & 0 deletions crates/factor-outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
) -> anyhow::Result<()> {
ctx.link_bindings(spin_world::v1::postgres::add_to_linker)?;
ctx.link_bindings(spin_world::v2::postgres::add_to_linker)?;
ctx.link_bindings(spin_world::spin::postgres::postgres::add_to_linker)?;
Ok(())
}

Expand Down
8 changes: 4 additions & 4 deletions crates/factor-outbound-pg/tests/factor_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ use spin_factor_variables::VariablesFactor;
use spin_factors::{anyhow, RuntimeFactors};
use spin_factors_test::{toml, TestEnvironment};
use spin_world::async_trait;
use spin_world::v2::postgres::HostConnection;
use spin_world::v2::postgres::{self as v2};
use spin_world::v2::rdbms_types::Error as PgError;
use spin_world::v2::rdbms_types::{ParameterValue, RowSet};
use spin_world::spin::postgres::postgres::HostConnection;
use spin_world::spin::postgres::postgres::{self as v2};
use spin_world::spin::postgres::rdbms_types::Error as PgError;
use spin_world::spin::postgres::rdbms_types::{ParameterValue, RowSet};

#[derive(RuntimeFactors)]
struct TestFactors {
Expand Down
Loading

0 comments on commit e14e1e4

Please sign in to comment.