diff --git a/Cargo.lock b/Cargo.lock index fbe4a8ffb542..3e2db866ce7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -520,15 +520,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "cfg_aliases" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" - -[[package]] -name = "cfg_aliases" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77e53693616d3075149f4ead59bdeecd204ac6b8192d8969757601b74bddf00f" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" @@ -676,6 +670,16 @@ dependencies = [ "unreachable", ] +[[package]] +name = "concat-idents" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f76990911f2267d837d9d0ad060aa63aaad170af40904b29461734c339030d4d" +dependencies = [ + "quote", + "syn 2.0.48", +] + [[package]] name = "connection-string" version = "0.2.0" @@ -3595,8 +3599,9 @@ dependencies = [ "bit-vec", "byteorder", "bytes", - "cfg_aliases 0.1.1", + "cfg_aliases", "chrono", + "concat-idents", "connection-string", "crosstarget-utils", "either", @@ -4261,7 +4266,7 @@ name = "request-handlers" version = "0.1.0" dependencies = [ "bigdecimal", - "cfg_aliases 0.2.0", + "cfg_aliases", "codspeed-criterion-compat", "connection-string", "dmmf", @@ -5428,9 +5433,9 @@ dependencies = [ [[package]] name = "tiberius" -version = "0.11.7" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66303a42b7c5daffb95c10cd8f3007a9c29b3e90128cf42b3738f58102aa2516" +checksum = "091052ba8f20c1e14f85913a5242a663a09d17ff4c0137b9b1f0735cb3c5dabc" dependencies = [ "async-native-tls", "async-trait", diff --git a/quaint/.envrc b/quaint/.envrc index 226b0ae234c1..6c7b5116c89a 100644 --- a/quaint/.envrc +++ b/quaint/.envrc @@ -1,7 +1,7 @@ export TEST_MYSQL="mysql://root:prisma@localhost:3306/prisma" export TEST_MYSQL8="mysql://root:prisma@localhost:3307/prisma" export TEST_MYSQL_MARIADB="mysql://root:prisma@localhost:3308/prisma" -export TEST_PSQL="postgres://postgres:prisma@localhost:5432/postgres" +export TEST_PSQL="postgresql://postgres:prisma@localhost:5432/postgres" export TEST_CRDB="postgresql://prisma@127.0.0.1:26259/postgres" export TEST_MSSQL="jdbc:sqlserver://localhost:1433;database=master;user=SA;password=;trustServerCertificate=true" if command -v nix-shell &> /dev/null diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index d30e27408782..ed0e3f9c69ef 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -93,6 +93,7 @@ serde = { version = "1.0" } sqlformat = { version = "0.2.3", optional = true } uuid.workspace = true crosstarget-utils = { path = "../libs/crosstarget-utils" } +concat-idents = "1.1.5" [dev-dependencies] once_cell = "1.3" @@ -125,12 +126,12 @@ features = ["chrono", "column_decltype"] optional = true [target.'cfg(not(any(target_os = "macos", target_os = "ios")))'.dependencies.tiberius] -version = "0.11.6" +version = "0.11.8" optional = true features = ["sql-browser-tokio", "chrono", "bigdecimal"] [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies.tiberius] -version = "0.11.2" +version = "0.11.8" optional = true default-features = false features = [ @@ -183,4 +184,4 @@ features = ["compat"] optional = true [build-dependencies] -cfg_aliases = "0.1.0" +cfg_aliases = "0.2.1" diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index e5b0f760be0d..1339bc473db6 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -9,6 +9,7 @@ //! implement the [Queryable](trait.Queryable.html) trait for generalized //! querying interface. +mod column_type; mod connection_info; pub mod external; @@ -24,6 +25,7 @@ mod transaction; mod type_identifier; pub use self::result_set::*; +pub use column_type::*; pub use connection_info::*; #[cfg(native)] diff --git a/quaint/src/connector/column_type.rs b/quaint/src/connector/column_type.rs new file mode 100644 index 000000000000..f01c7aa7953d --- /dev/null +++ b/quaint/src/connector/column_type.rs @@ -0,0 +1,177 @@ +#[cfg(not(target_arch = "wasm32"))] +use super::TypeIdentifier; + +use crate::{Value, ValueType}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ColumnType { + Int32, + Int64, + Float, + Double, + Text, + Bytes, + Boolean, + Char, + Numeric, + Json, + Xml, + Uuid, + DateTime, + Date, + Time, + Enum, + + Int32Array, + Int64Array, + FloatArray, + DoubleArray, + TextArray, + CharArray, + BytesArray, + BooleanArray, + NumericArray, + JsonArray, + XmlArray, + UuidArray, + DateTimeArray, + DateArray, + TimeArray, + + Unknown, +} + +impl ColumnType { + pub fn is_unknown(&self) -> bool { + matches!(self, ColumnType::Unknown) + } +} + +impl std::fmt::Display for ColumnType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ColumnType::Int32 => write!(f, "int"), + ColumnType::Int64 => write!(f, "bigint"), + ColumnType::Float => write!(f, "float"), + ColumnType::Double => write!(f, "double"), + ColumnType::Text => write!(f, "string"), + ColumnType::Enum => write!(f, "enum"), + ColumnType::Bytes => write!(f, "bytes"), + ColumnType::Boolean => write!(f, "bool"), + ColumnType::Char => write!(f, "char"), + ColumnType::Numeric => write!(f, "decimal"), + ColumnType::Json => write!(f, "json"), + ColumnType::Xml => write!(f, "xml"), + ColumnType::Uuid => write!(f, "uuid"), + ColumnType::DateTime => write!(f, "datetime"), + ColumnType::Date => write!(f, "date"), + ColumnType::Time => write!(f, "time"), + ColumnType::Int32Array => write!(f, "int-array"), + ColumnType::Int64Array => write!(f, "bigint-array"), + ColumnType::FloatArray => write!(f, "float-array"), + ColumnType::DoubleArray => write!(f, "double-array"), + ColumnType::TextArray => write!(f, "string-array"), + ColumnType::BytesArray => write!(f, "bytes-array"), + ColumnType::BooleanArray => write!(f, "bool-array"), + ColumnType::CharArray => write!(f, "char-array"), + ColumnType::NumericArray => write!(f, "decimal-array"), + ColumnType::JsonArray => write!(f, "json-array"), + ColumnType::XmlArray => write!(f, "xml-array"), + ColumnType::UuidArray => write!(f, "uuid-array"), + ColumnType::DateTimeArray => write!(f, "datetime-array"), + ColumnType::DateArray => write!(f, "date-array"), + ColumnType::TimeArray => write!(f, "time-array"), + + ColumnType::Unknown => write!(f, "unknown"), + } + } +} + +impl From<&Value<'_>> for ColumnType { + fn from(value: &Value<'_>) -> Self { + Self::from(&value.typed) + } +} + +impl From<&ValueType<'_>> for ColumnType { + fn from(value: &ValueType) -> Self { + match value { + ValueType::Int32(_) => ColumnType::Int32, + ValueType::Int64(_) => ColumnType::Int64, + ValueType::Float(_) => ColumnType::Float, + ValueType::Double(_) => ColumnType::Double, + ValueType::Text(_) => ColumnType::Text, + ValueType::Enum(_, _) => ColumnType::Enum, + ValueType::EnumArray(_, _) => ColumnType::TextArray, + ValueType::Bytes(_) => ColumnType::Bytes, + ValueType::Boolean(_) => ColumnType::Boolean, + ValueType::Char(_) => ColumnType::Char, + ValueType::Numeric(_) => ColumnType::Numeric, + ValueType::Json(_) => ColumnType::Json, + ValueType::Xml(_) => ColumnType::Xml, + ValueType::Uuid(_) => ColumnType::Uuid, + ValueType::DateTime(_) => ColumnType::DateTime, + ValueType::Date(_) => ColumnType::Date, + ValueType::Time(_) => ColumnType::Time, + ValueType::Array(Some(vals)) if !vals.is_empty() => match &vals[0].typed { + ValueType::Int32(_) => ColumnType::Int32Array, + ValueType::Int64(_) => ColumnType::Int64Array, + ValueType::Float(_) => ColumnType::FloatArray, + ValueType::Double(_) => ColumnType::DoubleArray, + ValueType::Text(_) => ColumnType::TextArray, + ValueType::Enum(_, _) => ColumnType::TextArray, + ValueType::Bytes(_) => ColumnType::BytesArray, + ValueType::Boolean(_) => ColumnType::BooleanArray, + ValueType::Char(_) => ColumnType::CharArray, + ValueType::Numeric(_) => ColumnType::NumericArray, + ValueType::Json(_) => ColumnType::JsonArray, + ValueType::Xml(_) => ColumnType::TextArray, + ValueType::Uuid(_) => ColumnType::UuidArray, + ValueType::DateTime(_) => ColumnType::DateTimeArray, + ValueType::Date(_) => ColumnType::DateArray, + ValueType::Time(_) => ColumnType::TimeArray, + ValueType::Array(_) => ColumnType::Unknown, + ValueType::EnumArray(_, _) => ColumnType::Unknown, + }, + ValueType::Array(_) => ColumnType::Unknown, + } + } +} + +impl ColumnType { + #[cfg(not(target_arch = "wasm32"))] + pub(crate) fn from_type_identifier(value: T) -> Self + where + T: TypeIdentifier, + { + if value.is_bool() { + ColumnType::Boolean + } else if value.is_bytes() { + ColumnType::Bytes + } else if value.is_date() { + ColumnType::Date + } else if value.is_datetime() { + ColumnType::DateTime + } else if value.is_time() { + ColumnType::Time + } else if value.is_double() { + ColumnType::Double + } else if value.is_float() { + ColumnType::Float + } else if value.is_int32() { + ColumnType::Int32 + } else if value.is_int64() { + ColumnType::Int64 + } else if value.is_enum() { + ColumnType::Enum + } else if value.is_json() { + ColumnType::Json + } else if value.is_real() { + ColumnType::Numeric + } else if value.is_text() { + ColumnType::Text + } else { + ColumnType::Unknown + } + } +} diff --git a/quaint/src/connector/mssql/native/column_type.rs b/quaint/src/connector/mssql/native/column_type.rs new file mode 100644 index 000000000000..a133b883b34e --- /dev/null +++ b/quaint/src/connector/mssql/native/column_type.rs @@ -0,0 +1,43 @@ +use crate::connector::ColumnType; +use tiberius::{Column, ColumnType as MssqlColumnType}; + +impl From<&Column> for ColumnType { + fn from(value: &Column) -> Self { + match value.column_type() { + MssqlColumnType::Null => ColumnType::Unknown, + + MssqlColumnType::BigVarChar + | MssqlColumnType::BigChar + | MssqlColumnType::NVarchar + | MssqlColumnType::NChar + | MssqlColumnType::Text + | MssqlColumnType::NText => ColumnType::Text, + + MssqlColumnType::Xml => ColumnType::Xml, + + MssqlColumnType::Bit | MssqlColumnType::Bitn => ColumnType::Boolean, + MssqlColumnType::Int1 | MssqlColumnType::Int2 | MssqlColumnType::Int4 => ColumnType::Int32, + MssqlColumnType::Int8 | MssqlColumnType::Intn => ColumnType::Int64, + + MssqlColumnType::Datetime2 + | MssqlColumnType::Datetime4 + | MssqlColumnType::Datetime + | MssqlColumnType::Datetimen + | MssqlColumnType::DatetimeOffsetn => ColumnType::DateTime, + + MssqlColumnType::Float4 => ColumnType::Float, + MssqlColumnType::Float8 | MssqlColumnType::Money | MssqlColumnType::Money4 | MssqlColumnType::Floatn => { + ColumnType::Double + } + MssqlColumnType::Guid => ColumnType::Uuid, + MssqlColumnType::Decimaln | MssqlColumnType::Numericn => ColumnType::Numeric, + MssqlColumnType::Daten => ColumnType::Date, + MssqlColumnType::Timen => ColumnType::Time, + MssqlColumnType::BigVarBin | MssqlColumnType::BigBinary | MssqlColumnType::Image => ColumnType::Bytes, + + MssqlColumnType::Udt | MssqlColumnType::SSVariant => { + unreachable!("UDT and SSVariant types are not supported by Tiberius.") + } + } + } +} diff --git a/quaint/src/connector/mssql/native/conversion.rs b/quaint/src/connector/mssql/native/conversion.rs index c6f2b1f37f48..5d2eb2eb08b8 100644 --- a/quaint/src/connector/mssql/native/conversion.rs +++ b/quaint/src/connector/mssql/native/conversion.rs @@ -3,8 +3,7 @@ use crate::ast::{Value, ValueType}; use bigdecimal::BigDecimal; use std::{borrow::Cow, convert::TryFrom}; -use tiberius::ToSql; -use tiberius::{ColumnData, FromSql, IntoSql}; +use tiberius::{ColumnData, FromSql, IntoSql, ToSql}; impl<'a> IntoSql<'a> for &'a Value<'a> { fn into_sql(self) -> ColumnData<'a> { diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 124e14ac94d0..18cad722c7c1 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -1,6 +1,7 @@ //! Definitions for the MSSQL connector. //! This module is not compatible with wasm32-* targets. //! This module is only available with the `mssql-native` feature. +mod column_type; mod conversion; mod error; @@ -9,7 +10,7 @@ use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions} use crate::{ ast::{Query, Value}, - connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, + connector::{metrics, queryable::*, ColumnType as QuaintColumnType, DefaultTransaction, ResultSet}, visitor::{self, Visitor}, }; use async_trait::async_trait; @@ -144,6 +145,10 @@ impl Queryable for Mssql { Some(rows) => { let mut columns_set = false; let mut columns = Vec::new(); + + let mut types_set = false; + let mut types = Vec::new(); + let mut result_rows = Vec::with_capacity(rows.len()); for row in rows.into_iter() { @@ -152,6 +157,11 @@ impl Queryable for Mssql { columns_set = true; } + if !types_set { + types = row.columns().iter().map(QuaintColumnType::from).collect(); + types_set = true; + } + let mut values: Vec> = Vec::with_capacity(row.len()); for val in row.into_iter() { @@ -161,9 +171,9 @@ impl Queryable for Mssql { result_rows.push(values); } - Ok(ResultSet::new(columns, result_rows)) + Ok(ResultSet::new(columns, types, result_rows)) } - None => Ok(ResultSet::new(Vec::new(), Vec::new())), + None => Ok(ResultSet::new(Vec::new(), Vec::new(), Vec::new())), } }) .await diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index f18fd6a0b94a..064bc152d25f 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,6 +1,7 @@ //! Wasm-compatible definitions for the MySQL connector. //! This module is only available with the `mysql` feature. mod defaults; + pub(crate) mod error; pub(crate) mod url; diff --git a/quaint/src/connector/mysql/native/column_type.rs b/quaint/src/connector/mysql/native/column_type.rs new file mode 100644 index 000000000000..801cd697460a --- /dev/null +++ b/quaint/src/connector/mysql/native/column_type.rs @@ -0,0 +1,8 @@ +use crate::connector::ColumnType; +use mysql_async::Column as MysqlColumn; + +impl From<&MysqlColumn> for ColumnType { + fn from(value: &MysqlColumn) -> Self { + ColumnType::from_type_identifier(value) + } +} diff --git a/quaint/src/connector/mysql/native/conversion.rs b/quaint/src/connector/mysql/native/conversion.rs index cccb1dc3130a..1a2d065f03af 100644 --- a/quaint/src/connector/mysql/native/conversion.rs +++ b/quaint/src/connector/mysql/native/conversion.rs @@ -80,7 +80,7 @@ pub fn conv_params(params: &[Value<'_>]) -> crate::Result { } } -impl TypeIdentifier for my::Column { +impl TypeIdentifier for &my::Column { fn is_real(&self) -> bool { use ColumnType::*; @@ -175,14 +175,19 @@ impl TypeIdentifier for my::Column { fn is_bytes(&self) -> bool { use ColumnType::*; - let is_a_blob = matches!( + let is_bytes = matches!( self.column_type(), - MYSQL_TYPE_TINY_BLOB | MYSQL_TYPE_MEDIUM_BLOB | MYSQL_TYPE_LONG_BLOB | MYSQL_TYPE_BLOB + MYSQL_TYPE_TINY_BLOB + | MYSQL_TYPE_MEDIUM_BLOB + | MYSQL_TYPE_LONG_BLOB + | MYSQL_TYPE_BLOB + | MYSQL_TYPE_VAR_STRING + | MYSQL_TYPE_STRING ) && self.character_set() == 63; let is_bits = self.column_type() == MYSQL_TYPE_BIT && self.column_length() > 1; - is_a_blob || is_bits + is_bytes || is_bits } fn is_bool(&self) -> bool { @@ -268,6 +273,20 @@ impl TakeRow for my::Row { })?), my::Value::Float(f) => Value::from(f), my::Value::Double(f) => Value::from(f), + my::Value::Date(year, month, day, _, _, _, _) if column.is_date() => { + if day == 0 || month == 0 { + let msg = format!( + "The column `{}` contained an invalid datetime value with either day or month set to zero.", + column.name_str() + ); + let kind = ErrorKind::value_out_of_range(msg); + return Err(Error::builder(kind).build()); + } + + let date = NaiveDate::from_ymd_opt(year.into(), month.into(), day.into()).unwrap(); + + Value::date(date) + } my::Value::Date(year, month, day, hour, min, sec, micro) => { if day == 0 || month == 0 { let msg = format!( diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 4ffdfc88b4cf..73dc6e529272 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -1,11 +1,12 @@ //! Definitions for the MySQL connector. //! This module is not compatible with wasm32-* targets. //! This module is only available with the `mysql-native` feature. +mod column_type; mod conversion; mod error; pub(crate) use crate::connector::mysql::MysqlUrl; -use crate::connector::{timeout, IsolationLevel}; +use crate::connector::{timeout, ColumnType, IsolationLevel}; use crate::{ ast::{Query, Value}, @@ -197,15 +198,40 @@ impl Queryable for Mysql { self.prepared(sql, |stmt| async move { let mut conn = self.conn.lock().await; let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; - let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); let last_id = conn.last_insert_id(); - let mut result_set = ResultSet::new(columns, Vec::new()); + + let mut result_rows = Vec::with_capacity(rows.len()); + let mut columns: Vec = Vec::new(); + let mut column_types: Vec = Vec::new(); + + let mut columns_set = false; for mut row in rows { - result_set.rows.push(row.take_result_row()?); + let row = row.take_result_row()?; + + if !columns_set { + for (idx, _) in row.iter().enumerate() { + let maybe_column = stmt.columns().get(idx); + // `mysql_async` does not return columns in `ResultSet` when a call to a stored procedure is done + // See https://github.com/prisma/prisma/issues/6173 + let column = maybe_column + .map(|col| col.name_str().into_owned()) + .unwrap_or_else(|| format!("f{idx}")); + let column_type = maybe_column.map(ColumnType::from).unwrap_or(ColumnType::Unknown); + + columns.push(column); + column_types.push(column_type); + } + + columns_set = true; + } + + result_rows.push(row); } + let mut result_set = ResultSet::new(columns, column_types, result_rows); + if let Some(id) = last_id { result_set.set_last_insert_id(id); }; diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 2ebb428f7dd6..e4d0b439f278 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,6 +1,7 @@ //! Wasm-compatible definitions for the PostgreSQL connector. //! This module is only available with the `postgresql` feature. mod defaults; + pub(crate) mod error; pub(crate) mod url; diff --git a/quaint/src/connector/postgres/native/column_type.rs b/quaint/src/connector/postgres/native/column_type.rs new file mode 100644 index 000000000000..9a3f441d4f72 --- /dev/null +++ b/quaint/src/connector/postgres/native/column_type.rs @@ -0,0 +1,129 @@ +use crate::connector::ColumnType; + +use std::borrow::Cow; +use tokio_postgres::types::{Kind as PostgresKind, Type as PostgresType}; + +macro_rules! create_pg_mapping { + ( + $($key:ident($typ: ty) => [$($value:ident),+]),* $(,)? + $([$pg_only_key:ident => $column_type_mapping:ident]),* + ) => { + // Generate PGColumnType enums + $( + concat_idents::concat_idents!(enum_name = PGColumnType, $key { + #[derive(Debug)] + #[allow(non_camel_case_types)] + #[allow(clippy::upper_case_acronyms)] + pub(crate) enum enum_name { + $($value,)* + } + }); + )* + + // Generate validators + $( + concat_idents::concat_idents!(struct_name = PGColumnValidator, $key { + #[derive(Debug)] + #[allow(non_camel_case_types)] + pub struct struct_name; + + impl struct_name { + #[inline] + #[allow(clippy::extra_unused_lifetimes)] + pub fn read<'a>(&self, val: $typ) -> $typ { + val + } + } + }); + )* + + pub(crate) enum PGColumnType { + $( + $key( + concat_idents::concat_idents!(variant = PGColumnType, $key, { variant }), + concat_idents::concat_idents!(enum_name = PGColumnValidator, $key, { enum_name }) + ), + )* + $($pg_only_key(concat_idents::concat_idents!(enum_name = PGColumnValidator, $column_type_mapping, { enum_name })),)* + } + + impl PGColumnType { + /// Takes a Postgres type and returns the corresponding ColumnType + #[deny(unreachable_patterns)] + pub(crate) fn from_pg_type(ty: &PostgresType) -> PGColumnType { + match ty { + $( + $( + &PostgresType::$value => PGColumnType::$key( + concat_idents::concat_idents!(variant = PGColumnType, $key, { variant::$value }), + concat_idents::concat_idents!(enum_name = PGColumnValidator, $key, { enum_name }), + ), + )* + )* + ref x => match x.kind() { + PostgresKind::Enum => PGColumnType::Enum(PGColumnValidatorText), + PostgresKind::Array(inner) => match inner.kind() { + PostgresKind::Enum => PGColumnType::EnumArray(PGColumnValidatorTextArray), + _ => PGColumnType::UnknownArray(PGColumnValidatorTextArray), + }, + _ => PGColumnType::Unknown(PGColumnValidatorText), + }, + } + } + } + + impl From for ColumnType { + fn from(ty: PGColumnType) -> ColumnType { + match ty { + $( + PGColumnType::$key(..) => ColumnType::$key, + )* + $( + PGColumnType::$pg_only_key(..) => ColumnType::$column_type_mapping, + )* + } + } + } + }; +} + +// Create a mapping between Postgres types and ColumnType and ensures there's a single source of truth. +// ColumnType() => [PostgresType(s)...] +create_pg_mapping! { + Boolean(Option) => [BOOL], + Int32(Option) => [INT2, INT4], + Int64(Option) => [INT8, OID], + Float(Option) => [FLOAT4], + Double(Option) => [FLOAT8], + Bytes(Option>) => [BYTEA], + Numeric(Option) => [NUMERIC, MONEY], + DateTime(Option>) => [TIMESTAMP, TIMESTAMPTZ], + Date(Option) => [DATE], + Time(Option) => [TIME, TIMETZ], + Text(Option>) => [INET, CIDR, BIT, VARBIT], + Uuid(Option) => [UUID], + Json(Option) => [JSON, JSONB], + Xml(Option>) => [XML], + Char(Option) => [CHAR], + + BooleanArray(impl Iterator>) => [BOOL_ARRAY], + Int32Array(impl Iterator>) => [INT2_ARRAY, INT4_ARRAY], + Int64Array(impl Iterator>) => [INT8_ARRAY, OID_ARRAY], + FloatArray(impl Iterator>) => [FLOAT4_ARRAY], + DoubleArray(impl Iterator>) => [FLOAT8_ARRAY], + BytesArray(impl Iterator>>) => [BYTEA_ARRAY], + NumericArray(impl Iterator>) => [NUMERIC_ARRAY, MONEY_ARRAY], + DateTimeArray(impl Iterator>>) => [TIMESTAMP_ARRAY, TIMESTAMPTZ_ARRAY], + DateArray(impl Iterator>) => [DATE_ARRAY], + TimeArray(impl Iterator>) => [TIME_ARRAY, TIMETZ_ARRAY], + TextArray(impl Iterator>>) => [TEXT_ARRAY, NAME_ARRAY, VARCHAR_ARRAY, INET_ARRAY, CIDR_ARRAY, BIT_ARRAY, VARBIT_ARRAY, XML_ARRAY], + UuidArray(impl Iterator>) => [UUID_ARRAY], + JsonArray(impl Iterator>) => [JSON_ARRAY, JSONB_ARRAY], + + // For the cases where the Postgres type is not directly mappable to ColumnType, use the following: + // [PGColumnType => ColumnType] + [Enum => Text], + [EnumArray => TextArray], + [UnknownArray => TextArray], + [Unknown => Text] +} diff --git a/quaint/src/connector/postgres/native/conversion.rs b/quaint/src/connector/postgres/native/conversion.rs index 4479eed69c69..fefe62a96d9a 100644 --- a/quaint/src/connector/postgres/native/conversion.rs +++ b/quaint/src/connector/postgres/native/conversion.rs @@ -4,8 +4,11 @@ use crate::{ ast::{Value, ValueType}, connector::queryable::{GetRow, ToColumnNames}, error::{Error, ErrorKind}, + prelude::EnumVariant, }; +use super::column_type::*; + use bigdecimal::{num_bigint::BigInt, BigDecimal, FromPrimitive, ToPrimitive}; use bit_vec::BitVec; use bytes::BytesMut; @@ -13,7 +16,7 @@ use chrono::{DateTime, NaiveDateTime, Utc}; pub(crate) use decimal::DecimalWrapper; use postgres_types::{FromSql, ToSql, WrongType}; -use std::{convert::TryFrom, error::Error as StdError}; +use std::{borrow::Cow, convert::TryFrom, error::Error as StdError}; use tokio_postgres::{ types::{self, IsNull, Kind, Type as PostgresType}, Row as PostgresRow, Statement as PostgresStatement, @@ -162,411 +165,528 @@ impl<'a> FromSql<'a> for NaiveMoney { impl GetRow for PostgresRow { fn get_result_row(&self) -> crate::Result>> { fn convert(row: &PostgresRow, i: usize) -> crate::Result> { - let result = match *row.columns()[i].type_() { - PostgresType::BOOL => ValueType::Boolean(row.try_get(i)?).into_value(), - PostgresType::INT2 => match row.try_get(i)? { - Some(val) => { - let val: i16 = val; - Value::int32(val) - } - None => Value::null_int32(), - }, - PostgresType::INT4 => match row.try_get(i)? { - Some(val) => { - let val: i32 = val; - Value::int32(val) - } - None => Value::null_int32(), - }, - PostgresType::INT8 => match row.try_get(i)? { - Some(val) => { - let val: i64 = val; - Value::int64(val) - } - None => Value::null_int64(), - }, - PostgresType::FLOAT4 => match row.try_get(i)? { - Some(val) => { - let val: f32 = val; - Value::float(val) - } - None => Value::null_float(), - }, - PostgresType::FLOAT8 => match row.try_get(i)? { - Some(val) => { - let val: f64 = val; - Value::double(val) - } - None => Value::null_double(), + let pg_ty = row.columns()[i].type_(); + let column_type = PGColumnType::from_pg_type(pg_ty); + + // This convoluted nested enum is macro-generated to ensure we have a single source of truth for + // the mapping between Postgres types and ColumnType. The macro is in `./column_type.rs`. + // PGColumnValidator are used to softly ensure that the correct `ValueType` variants are created. + // If you ever add a new type or change some mapping, please ensure you pass the data through `v.read()`. + let result = match column_type { + PGColumnType::Boolean(ty, v) => match ty { + PGColumnTypeBoolean::BOOL => ValueType::Boolean(v.read(row.try_get(i)?)), }, - PostgresType::BYTEA => match row.try_get(i)? { - Some(val) => { - let val: &[u8] = val; - Value::bytes(val.to_owned()) - } - None => Value::null_bytes(), - }, - PostgresType::BYTEA_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec>> = val; - let byteas = val.into_iter().map(|b| ValueType::Bytes(b.map(Into::into))); + PGColumnType::Int32(ty, v) => match ty { + PGColumnTypeInt32::INT2 => { + let val: Option = row.try_get(i)?; - Value::array(byteas) + ValueType::Int32(v.read(val.map(i32::from))) } - None => Value::null_array(), - }, - PostgresType::NUMERIC => { - let dw: Option = row.try_get(i)?; + PGColumnTypeInt32::INT4 => { + let val: Option = row.try_get(i)?; - ValueType::Numeric(dw.map(|dw| dw.0)).into_value() - } - PostgresType::MONEY => match row.try_get(i)? { - Some(val) => { - let val: NaiveMoney = val; - Value::numeric(val.0) + ValueType::Int32(v.read(val)) } - None => Value::null_numeric(), }, - PostgresType::TIMESTAMP => match row.try_get(i)? { - Some(val) => { - let ts: NaiveDateTime = val; - let dt = DateTime::::from_naive_utc_and_offset(ts, Utc); - Value::datetime(dt) + PGColumnType::Int64(ty, v) => match ty { + PGColumnTypeInt64::INT8 => { + let val = v.read(row.try_get(i)?); + + ValueType::Int64(val) } - None => Value::null_datetime(), - }, - PostgresType::TIMESTAMPTZ => match row.try_get(i)? { - Some(val) => { - let ts: DateTime = val; - Value::datetime(ts) + PGColumnTypeInt64::OID => { + let val: Option = row.try_get(i)?; + + ValueType::Int64(v.read(val.map(i64::from))) } - None => Value::null_datetime(), }, - PostgresType::DATE => match row.try_get(i)? { - Some(val) => Value::date(val), - None => Value::null_date(), + PGColumnType::Float(ty, v) => match ty { + PGColumnTypeFloat::FLOAT4 => ValueType::Float(v.read(row.try_get(i)?)), }, - PostgresType::TIME => match row.try_get(i)? { - Some(val) => Value::time(val), - None => Value::null_time(), + PGColumnType::Double(ty, v) => match ty { + PGColumnTypeDouble::FLOAT8 => ValueType::Double(v.read(row.try_get(i)?)), }, - PostgresType::TIMETZ => match row.try_get(i)? { - Some(val) => { - let time: TimeTz = val; - Value::time(time.0) + PGColumnType::Bytes(ty, v) => match ty { + PGColumnTypeBytes::BYTEA => { + let val: Option<&[u8]> = row.try_get(i)?; + let val = val.map(ToOwned::to_owned).map(Cow::Owned); + + ValueType::Bytes(v.read(val)) } - None => Value::null_time(), }, - PostgresType::UUID => match row.try_get(i)? { - Some(val) => { - let val: Uuid = val; - Value::uuid(val) + PGColumnType::Text(ty, v) => match ty { + PGColumnTypeText::INET | PGColumnTypeText::CIDR => { + let val: Option = row.try_get(i)?; + let val = val.map(|val| val.to_string()).map(Cow::from); + + ValueType::Text(v.read(val)) } - None => ValueType::Uuid(None).into_value(), - }, - PostgresType::UUID_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let val = val.into_iter().map(ValueType::Uuid); + PGColumnTypeText::VARBIT | PGColumnTypeText::BIT => { + let val: Option = row.try_get(i)?; + let val_str = val.map(|val| bits_to_string(&val)).transpose()?.map(Cow::Owned); - Value::array(val) + ValueType::Text(v.read(val_str)) } - None => Value::null_array(), }, - PostgresType::JSON | PostgresType::JSONB => ValueType::Json(row.try_get(i)?).into_value(), - PostgresType::INT2_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let ints = val.into_iter().map(|i| ValueType::Int32(i.map(|i| i as i32))); + PGColumnType::Char(ty, v) => match ty { + PGColumnTypeChar::CHAR => { + let val: Option = row.try_get(i)?; + let val = val.map(|val| (val as u8) as char); - Value::array(ints) + ValueType::Char(v.read(val)) } - None => Value::null_array(), }, - PostgresType::INT4_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let ints = val.into_iter().map(ValueType::Int32); + PGColumnType::Numeric(ty, v) => match ty { + PGColumnTypeNumeric::NUMERIC => { + let dw: Option = row.try_get(i)?; + let val = dw.map(|dw| dw.0); - Value::array(ints) + ValueType::Numeric(v.read(val)) } - None => Value::null_array(), - }, - PostgresType::INT8_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let ints = val.into_iter().map(ValueType::Int64); + PGColumnTypeNumeric::MONEY => { + let val: Option = row.try_get(i)?; - Value::array(ints) + ValueType::Numeric(v.read(val.map(|val| val.0))) } - None => Value::null_array(), }, - PostgresType::FLOAT4_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let floats = val.into_iter().map(ValueType::Float); + PGColumnType::DateTime(ty, v) => match ty { + PGColumnTypeDateTime::TIMESTAMP => { + let ts: Option = row.try_get(i)?; + let dt = ts.map(|ts| DateTime::::from_naive_utc_and_offset(ts, Utc)); - Value::array(floats) + ValueType::DateTime(v.read(dt)) } - None => Value::null_array(), - }, - PostgresType::FLOAT8_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let floats = val.into_iter().map(ValueType::Double); + PGColumnTypeDateTime::TIMESTAMPTZ => { + let ts: Option> = row.try_get(i)?; - Value::array(floats) + ValueType::DateTime(v.read(ts)) } - None => Value::null_array(), }, - PostgresType::BOOL_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let bools = val.into_iter().map(ValueType::Boolean); + PGColumnType::Date(ty, v) => match ty { + PGColumnTypeDate::DATE => ValueType::Date(v.read(row.try_get(i)?)), + }, + PGColumnType::Time(ty, v) => match ty { + PGColumnTypeTime::TIME => ValueType::Time(v.read(row.try_get(i)?)), + PGColumnTypeTime::TIMETZ => { + let val: Option = row.try_get(i)?; - Value::array(bools) + ValueType::Time(v.read(val.map(|val| val.0))) } - None => Value::null_array(), }, - PostgresType::TIMESTAMP_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - - let dates = val.into_iter().map(|dt| { - ValueType::DateTime(dt.map(|dt| DateTime::::from_naive_utc_and_offset(dt, Utc))) - }); + PGColumnType::Json(ty, v) => match ty { + PGColumnTypeJson::JSON | PGColumnTypeJson::JSONB => ValueType::Json(v.read(row.try_get(i)?)), + }, + PGColumnType::Xml(ty, v) => match ty { + PGColumnTypeXml::XML => { + let val: Option = row.try_get(i)?; - Value::array(dates) + ValueType::Xml(v.read(val.map(|val| Cow::Owned(val.0)))) } - None => Value::null_array(), }, - PostgresType::NUMERIC_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; + PGColumnType::Uuid(ty, v) => match ty { + PGColumnTypeUuid::UUID => ValueType::Uuid(v.read(row.try_get(i)?)), + }, + PGColumnType::Int32Array(ty, v) => match ty { + PGColumnTypeInt32Array::INT2_ARRAY => { + let vals: Option>> = row.try_get(i)?; - let decimals = val - .into_iter() - .map(|dec| ValueType::Numeric(dec.map(|dec| dec.0.to_string().parse().unwrap()))); + match vals { + Some(vals) => { + let ints = vals.into_iter().map(|val| val.map(i32::from)); - Value::array(decimals) + ValueType::Array(Some( + v.read(ints).map(ValueType::Int32).map(ValueType::into_value).collect(), + )) + } + None => ValueType::Array(None), + } + } + PGColumnTypeInt32Array::INT4_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Int32) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::TEXT_ARRAY | PostgresType::NAME_ARRAY | PostgresType::VARCHAR_ARRAY => { - match row.try_get(i)? { - Some(val) => { - let strings: Vec> = val; - - Value::array(strings.into_iter().map(|s| s.map(|s| s.to_string()))) + PGColumnType::Int64Array(ty, v) => match ty { + PGColumnTypeInt64Array::INT8_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Int64) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), } - None => Value::null_array(), } - } - PostgresType::MONEY_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let nums = val.into_iter().map(|num| ValueType::Numeric(num.map(|num| num.0))); + PGColumnTypeInt64Array::OID_ARRAY => { + let vals: Option>> = row.try_get(i)?; - Value::array(nums) - } - None => Value::null_array(), - }, - PostgresType::OID_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let nums = val.into_iter().map(|oid| ValueType::Int64(oid.map(|oid| oid as i64))); + match vals { + Some(vals) => { + let oids = vals.into_iter().map(|oid| oid.map(i64::from)); - Value::array(nums) + ValueType::Array(Some( + v.read(oids).map(ValueType::Int64).map(ValueType::into_value).collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::TIMESTAMPTZ_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec>> = val; - let dates = val.into_iter().map(ValueType::DateTime); - - Value::array(dates) + PGColumnType::FloatArray(ty, v) => match ty { + PGColumnTypeFloatArray::FLOAT4_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Float) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::DATE_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let dates = val.into_iter().map(ValueType::Date); - - Value::array(dates) + PGColumnType::DoubleArray(ty, v) => match ty { + PGColumnTypeDoubleArray::FLOAT8_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Double) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::TIME_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let times = val.into_iter().map(ValueType::Time); - - Value::array(times) + PGColumnType::TextArray(ty, v) => match ty { + PGColumnTypeTextArray::TEXT_ARRAY + | PGColumnTypeTextArray::NAME_ARRAY + | PGColumnTypeTextArray::VARCHAR_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let strings = vals.into_iter().map(|s| s.map(ToOwned::to_owned).map(Cow::Owned)); + + ValueType::Array(Some( + v.read(strings) + .map(ValueType::Text) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), - }, - PostgresType::TIMETZ_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let timetzs = val.into_iter().map(|time| ValueType::Time(time.map(|time| time.0))); + PGColumnTypeTextArray::INET_ARRAY | PGColumnTypeTextArray::CIDR_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let addrs = vals + .into_iter() + .map(|ip| ip.as_ref().map(ToString::to_string).map(Cow::Owned)); - Value::array(timetzs) + ValueType::Array(Some( + v.read(addrs).map(ValueType::Text).map(ValueType::into_value).collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), - }, - PostgresType::JSON_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let jsons = val.into_iter().map(ValueType::Json); + PGColumnTypeTextArray::BIT_ARRAY | PGColumnTypeTextArray::VARBIT_ARRAY => { + let vals: Option>> = row.try_get(i)?; - Value::array(jsons) + match vals { + Some(vals) => { + let vals = vals + .into_iter() + .map(|bits| bits.map(|bits| bits_to_string(&bits).map(Cow::Owned)).transpose()) + .collect::>>()?; + + ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Text) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), - }, - PostgresType::JSONB_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let jsons = val.into_iter().map(ValueType::Json); + PGColumnTypeTextArray::XML_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let xmls = vals.into_iter().map(|xml| xml.map(|xml| xml.0).map(Cow::Owned)); - Value::array(jsons) + ValueType::Array(Some( + v.read(xmls).map(ValueType::Text).map(ValueType::into_value).collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::OID => match row.try_get(i)? { - Some(val) => { - let val: u32 = val; - Value::int64(val) + PGColumnType::BytesArray(ty, v) => match ty { + PGColumnTypeBytesArray::BYTEA_ARRAY => { + let vals: Option>>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(|b| b.map(Cow::Owned)) + .map(ValueType::Bytes) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_int64(), }, - PostgresType::CHAR => match row.try_get(i)? { - Some(val) => { - let val: i8 = val; - Value::character((val as u8) as char) + PGColumnType::BooleanArray(ty, v) => match ty { + PGColumnTypeBooleanArray::BOOL_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Boolean) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_character(), }, - PostgresType::INET | PostgresType::CIDR => match row.try_get(i)? { - Some(val) => { - let val: std::net::IpAddr = val; - Value::text(val.to_string()) + PGColumnType::NumericArray(ty, v) => match ty { + PGColumnTypeNumericArray::NUMERIC_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let decimals = vals.into_iter().map(|dec| dec.map(|dec| dec.0)); + + ValueType::Array(Some( + v.read(decimals.into_iter()) + .map(ValueType::Numeric) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_text(), - }, - PostgresType::INET_ARRAY | PostgresType::CIDR_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let addrs = val - .into_iter() - .map(|ip| ValueType::Text(ip.map(|ip| ip.to_string().into()))); - - Value::array(addrs) + PGColumnTypeNumericArray::MONEY_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let nums = vals.into_iter().map(|num| num.map(|num| num.0)); + + ValueType::Array(Some( + v.read(nums.into_iter()) + .map(ValueType::Numeric) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::BIT | PostgresType::VARBIT => match row.try_get(i)? { - Some(val) => { - let val: BitVec = val; - Value::text(bits_to_string(&val)?) + PGColumnType::JsonArray(ty, v) => match ty { + PGColumnTypeJsonArray::JSON_ARRAY | PGColumnTypeJsonArray::JSONB_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Json) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_text(), }, - PostgresType::BIT_ARRAY | PostgresType::VARBIT_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - val.into_iter() - .map(|bits| match bits { - Some(bits) => bits_to_string(&bits).map(|s| ValueType::Text(Some(s.into()))), - None => Ok(ValueType::Text(None)), - }) - .collect::>>() - .map(Value::array)? - } - None => Value::null_array(), + PGColumnType::UuidArray(ty, v) => match ty { + PGColumnTypeUuidArray::UUID_ARRAY => match row.try_get(i)? { + Some(vals) => { + let vals: Vec> = vals; + + ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Uuid) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + }, }, - PostgresType::XML => match row.try_get(i)? { - Some(val) => { - let val: XmlString = val; - Value::xml(val.0) - } - None => Value::null_xml(), + PGColumnType::DateTimeArray(ty, v) => match ty { + PGColumnTypeDateTimeArray::TIMESTAMP_ARRAY => match row.try_get(i)? { + Some(vals) => { + let vals: Vec> = vals; + let dates = vals + .into_iter() + .map(|dt| dt.map(|dt| DateTime::::from_naive_utc_and_offset(dt, Utc))); + + ValueType::Array(Some( + v.read(dates) + .map(ValueType::DateTime) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + }, + PGColumnTypeDateTimeArray::TIMESTAMPTZ_ARRAY => match row.try_get(i)? { + Some(vals) => { + let vals: Vec>> = vals; + + ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::DateTime) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + }, }, - PostgresType::XML_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let xmls = val.into_iter().map(|xml| xml.map(|xml| xml.0)); - - Value::array(xmls) - } - None => Value::null_array(), + PGColumnType::DateArray(ty, v) => match ty { + PGColumnTypeDateArray::DATE_ARRAY => match row.try_get(i)? { + Some(vals) => { + let vals: Vec> = vals; + + ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Date) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + }, }, - ref x => match x.kind() { - Kind::Enum => match row.try_get(i)? { + PGColumnType::TimeArray(ty, v) => match ty { + PGColumnTypeTimeArray::TIME_ARRAY => match row.try_get(i)? { + Some(vals) => { + let vals: Vec> = vals; + + ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Time) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + }, + PGColumnTypeTimeArray::TIMETZ_ARRAY => match row.try_get(i)? { Some(val) => { - let val: EnumString = val; - - Value::enum_variant(val.value) + let val: Vec> = val; + let timetzs = val.into_iter().map(|time| time.map(|time| time.0)); + + ValueType::Array(Some( + v.read(timetzs.into_iter()) + .map(ValueType::Time) + .map(ValueType::into_value) + .collect(), + )) } - None => Value::null_enum(), + None => ValueType::Array(None), }, - Kind::Array(inner) => match inner.kind() { - Kind::Enum => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let variants = val - .into_iter() - .map(|x| ValueType::Enum(x.map(|x| x.value.into()), None)); - - Ok(Value::array(variants)) - } - None => Ok(Value::null_array()), - }, - _ => match row.try_get(i) { - Ok(Some(val)) => { - let val: Vec> = val; - let strings = val.into_iter().map(|str| ValueType::Text(str.map(Into::into))); - - Ok(Value::array(strings)) - } - Ok(None) => Ok(Value::null_array()), - Err(err) => { - if err.source().map(|err| err.is::()).unwrap_or(false) { - let kind = ErrorKind::UnsupportedColumnType { - column_type: x.to_string(), - }; - - return Err(Error::builder(kind).build()); - } else { - Err(err) - } - } - }, - }?, - _ => match row.try_get(i) { - Ok(Some(val)) => { - let val: String = val; + }, + PGColumnType::EnumArray(v) => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let enums = vals.into_iter().map(|val| val.map(|val| Cow::Owned(val.value))); + + ValueType::Array(Some( + v.read(enums) + .map(|variant| ValueType::Enum(variant.map(EnumVariant::new), None)) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + } + } + PGColumnType::Enum(v) => { + let val: Option = row.try_get(i)?; + let enum_variant = v.read(val.map(|x| Cow::Owned(x.value))); - Ok(Value::text(val)) + ValueType::Enum(enum_variant.map(EnumVariant::new), None) + } + PGColumnType::UnknownArray(v) => match row.try_get(i) { + Ok(Some(vals)) => { + let vals: Vec> = vals; + let strings = vals.into_iter().map(|str| str.map(Cow::Owned)); + + Ok(ValueType::Array(Some( + v.read(strings.into_iter()) + .map(ValueType::Text) + .map(ValueType::into_value) + .collect(), + ))) + } + Ok(None) => Ok(ValueType::Array(None)), + Err(err) => { + if err.source().map(|err| err.is::()).unwrap_or(false) { + let kind = ErrorKind::UnsupportedColumnType { + column_type: pg_ty.to_string(), + }; + + return Err(Error::builder(kind).build()); + } else { + Err(err) } - Ok(None) => Ok(Value::from(ValueType::Text(None))), - Err(err) => { - if err.source().map(|err| err.is::()).unwrap_or(false) { - let kind = ErrorKind::UnsupportedColumnType { - column_type: x.to_string(), - }; - - return Err(Error::builder(kind).build()); - } else { - Err(err) - } + } + }?, + PGColumnType::Unknown(v) => match row.try_get(i) { + Ok(Some(val)) => { + let val: String = val; + + Ok(ValueType::Text(v.read(Some(Cow::Owned(val))))) + } + Ok(None) => Ok(ValueType::Text(None)), + Err(err) => { + if err.source().map(|err| err.is::()).unwrap_or(false) { + let kind = ErrorKind::UnsupportedColumnType { + column_type: pg_ty.to_string(), + }; + + return Err(Error::builder(kind).build()); + } else { + Err(err) } - }?, - }, + } + }?, }; - Ok(result) + Ok(result.into_value()) } let num_columns = self.columns().len(); diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 85bd34066b90..8768b4a60a5f 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -1,12 +1,13 @@ //! Definitions for the Postgres connector. //! This module is not compatible with wasm32-* targets. //! This module is only available with the `postgresql-native` feature. +pub(crate) mod column_type; mod conversion; mod error; pub(crate) use crate::connector::postgres::url::PostgresUrl; use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; -use crate::connector::{timeout, IsolationLevel, Transaction}; +use crate::connector::{timeout, ColumnType, IsolationLevel, Transaction}; use crate::error::NativeErrorKind; use crate::{ @@ -16,6 +17,7 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; +use column_type::PGColumnType; use futures::{future::FutureExt, lock::Mutex}; use lru_cache::LruCache; use native_tls::{Certificate, Identity, TlsConnector}; @@ -416,7 +418,13 @@ impl Queryable for PostgreSql { .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) .await?; - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + let col_types = stmt + .columns() + .iter() + .map(|c| PGColumnType::from_pg_type(c.type_())) + .map(ColumnType::from) + .collect::>(); + let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); for row in rows { result.rows.push(row.get_result_row()?); @@ -442,11 +450,17 @@ impl Queryable for PostgreSql { return Err(Error::builder(kind).build()); } + let col_types = stmt + .columns() + .iter() + .map(|c| PGColumnType::from_pg_type(c.type_())) + .map(ColumnType::from) + .collect::>(); let rows = self .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) .await?; - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); for row in rows { result.rows.push(row.get_result_row()?); diff --git a/quaint/src/connector/result_set.rs b/quaint/src/connector/result_set.rs index 7592a85ea15c..7e7d045f9842 100644 --- a/quaint/src/connector/result_set.rs +++ b/quaint/src/connector/result_set.rs @@ -8,19 +8,23 @@ use crate::{ast::Value, error::*}; use serde_json::Map; use std::sync::Arc; +use super::ColumnType; + /// Encapsulates a set of results and their respective column names. #[derive(Debug, Default)] pub struct ResultSet { pub(crate) columns: Arc>, + pub(crate) types: Vec, pub(crate) rows: Vec>>, pub(crate) last_insert_id: Option, } impl ResultSet { /// Creates a new instance, bound to the given column names and result rows. - pub fn new(names: Vec, rows: Vec>>) -> Self { + pub fn new(names: Vec, types: Vec, rows: Vec>>) -> Self { Self { columns: Arc::new(names), + types, rows, last_insert_id: None, } @@ -61,6 +65,7 @@ impl ResultSet { self.rows.get(index).map(|row| ResultRowRef { columns: Arc::clone(&self.columns), values: row, + types: self.types.clone(), }) } @@ -75,9 +80,14 @@ impl ResultSet { pub fn iter(&self) -> ResultSetIterator<'_> { ResultSetIterator { columns: self.columns.clone(), + types: self.types.clone(), internal_iterator: self.rows.iter(), } } + + pub fn types(&self) -> &[ColumnType] { + &self.types + } } impl IntoIterator for ResultSet { @@ -87,6 +97,7 @@ impl IntoIterator for ResultSet { fn into_iter(self) -> Self::IntoIter { ResultSetIntoIterator { columns: self.columns, + types: self.types.clone(), internal_iterator: self.rows.into_iter(), } } @@ -96,6 +107,7 @@ impl IntoIterator for ResultSet { /// Might become lazy one day. pub struct ResultSetIntoIterator { pub(crate) columns: Arc>, + pub(crate) types: Vec, pub(crate) internal_iterator: std::vec::IntoIter>>, } @@ -107,6 +119,7 @@ impl Iterator for ResultSetIntoIterator { Some(row) => Some(ResultRow { columns: Arc::clone(&self.columns), values: row, + types: self.types.clone(), }), None => None, } @@ -115,6 +128,7 @@ impl Iterator for ResultSetIntoIterator { pub struct ResultSetIterator<'a> { pub(crate) columns: Arc>, + pub(crate) types: Vec, pub(crate) internal_iterator: std::slice::Iter<'a, Vec>>, } @@ -126,6 +140,7 @@ impl<'a> Iterator for ResultSetIterator<'a> { Some(row) => Some(ResultRowRef { columns: Arc::clone(&self.columns), values: row, + types: self.types.clone(), }), None => None, } diff --git a/quaint/src/connector/result_set/result_row.rs b/quaint/src/connector/result_set/result_row.rs index a6a5b55c3c62..ae2b068a5814 100644 --- a/quaint/src/connector/result_set/result_row.rs +++ b/quaint/src/connector/result_set/result_row.rs @@ -1,5 +1,6 @@ use crate::{ ast::Value, + connector::ColumnType, error::{Error, ErrorKind}, }; use std::sync::Arc; @@ -9,6 +10,7 @@ use std::sync::Arc; #[derive(Debug, PartialEq)] pub struct ResultRow { pub(crate) columns: Arc>, + pub(crate) types: Vec, pub(crate) values: Vec>, } @@ -38,6 +40,7 @@ impl IntoIterator for ResultRow { #[derive(Debug, PartialEq)] pub struct ResultRowRef<'a> { pub(crate) columns: Arc>, + pub(crate) types: Vec, pub(crate) values: &'a Vec>, } @@ -70,6 +73,7 @@ impl ResultRow { ResultRowRef { columns: Arc::clone(&self.columns), values: &self.values, + types: self.types.clone(), } } diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 05f073d9c34e..3a30b38975d9 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,6 +1,7 @@ //! Wasm-compatible definitions for the SQLite connector. //! This module is only available with the `sqlite` feature. mod defaults; + pub(crate) mod error; mod ffi; pub(crate) mod params; diff --git a/quaint/src/connector/sqlite/native/column_type.rs b/quaint/src/connector/sqlite/native/column_type.rs new file mode 100644 index 000000000000..e8f1291a3a15 --- /dev/null +++ b/quaint/src/connector/sqlite/native/column_type.rs @@ -0,0 +1,14 @@ +use rusqlite::Column; + +use crate::connector::{ColumnType, TypeIdentifier}; + +impl From<&Column<'_>> for ColumnType { + fn from(value: &Column) -> Self { + if value.is_float() { + // Sqlite always returns Double for floats + ColumnType::Double + } else { + ColumnType::from_type_identifier(value) + } + } +} diff --git a/quaint/src/connector/sqlite/native/conversion.rs b/quaint/src/connector/sqlite/native/conversion.rs index afd0145fade8..b06be6487acd 100644 --- a/quaint/src/connector/sqlite/native/conversion.rs +++ b/quaint/src/connector/sqlite/native/conversion.rs @@ -16,7 +16,7 @@ use rusqlite::{ use chrono::TimeZone; -impl TypeIdentifier for Column<'_> { +impl TypeIdentifier for &Column<'_> { fn is_real(&self) -> bool { match self.decl_type() { Some(n) if n.starts_with("DECIMAL") => true, @@ -82,7 +82,6 @@ impl TypeIdentifier for Column<'_> { ) } - #[cfg(feature = "mysql")] fn is_time(&self) -> bool { false } @@ -119,12 +118,10 @@ impl TypeIdentifier for Column<'_> { matches!(self.decl_type(), Some("BOOLEAN") | Some("boolean")) } - #[cfg(feature = "mysql")] fn is_json(&self) -> bool { false } - #[cfg(feature = "mysql")] fn is_enum(&self) -> bool { false } @@ -146,8 +143,7 @@ impl<'a> GetRow for SqliteRow<'a> { c if c.is_int64() => Value::null_int64(), c if c.is_text() => Value::null_text(), c if c.is_bytes() => Value::null_bytes(), - c if c.is_float() => Value::null_float(), - c if c.is_double() => Value::null_double(), + c if c.is_float() || c.is_double() => Value::null_double(), c if c.is_real() => Value::null_numeric(), c if c.is_datetime() => Value::null_datetime(), c if c.is_date() => Value::null_date(), @@ -251,7 +247,9 @@ impl<'a> ToSql for Value<'a> { let value = match &self.typed { ValueType::Int32(integer) => integer.map(ToSqlOutput::from), ValueType::Int64(integer) => integer.map(ToSqlOutput::from), - ValueType::Float(float) => float.map(|f| f as f64).map(ToSqlOutput::from), + ValueType::Float(float) => { + float.map(|float| ToSqlOutput::from(float.to_string().parse::().expect("f32 is not a f64."))) + } ValueType::Double(double) => double.map(ToSqlOutput::from), ValueType::Text(cow) => cow.as_ref().map(|cow| ToSqlOutput::from(cow.as_ref())), ValueType::Enum(cow, _) => cow.as_ref().map(|cow| ToSqlOutput::from(cow.as_ref())), diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 58ef03799e2f..92dd7cd7bf2b 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -1,11 +1,12 @@ //! Definitions for the SQLite connector. //! This module is not compatible with wasm32-* targets. //! This module is only available with the `sqlite-native` feature. +mod column_type; mod conversion; mod error; -use crate::connector::sqlite::params::SqliteParams; use crate::connector::IsolationLevel; +use crate::connector::{sqlite::params::SqliteParams, ColumnType}; pub use rusqlite::{params_from_iter, version as sqlite_version}; @@ -104,8 +105,9 @@ impl Queryable for Sqlite { let mut stmt = client.prepare_cached(sql)?; + let col_types = stmt.columns().iter().map(ColumnType::from).collect::>(); let mut rows = stmt.query(params_from_iter(params.iter()))?; - let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); + let mut result = ResultSet::new(rows.to_column_names(), col_types, Vec::new()); while let Some(row) = rows.next()? { result.rows.push(row.get_result_row()?); diff --git a/quaint/src/connector/type_identifier.rs b/quaint/src/connector/type_identifier.rs index ce27ea89a404..9fcc46f61c1c 100644 --- a/quaint/src/connector/type_identifier.rs +++ b/quaint/src/connector/type_identifier.rs @@ -5,15 +5,12 @@ pub(crate) trait TypeIdentifier { fn is_int32(&self) -> bool; fn is_int64(&self) -> bool; fn is_datetime(&self) -> bool; - #[cfg(feature = "mysql")] fn is_time(&self) -> bool; fn is_date(&self) -> bool; fn is_text(&self) -> bool; fn is_bytes(&self) -> bool; fn is_bool(&self) -> bool; - #[cfg(feature = "mysql")] fn is_json(&self) -> bool; - #[cfg(feature = "mysql")] fn is_enum(&self) -> bool; fn is_null(&self) -> bool; } diff --git a/quaint/src/macros.rs b/quaint/src/macros.rs index cfb52bc0c6e1..9921359669c9 100644 --- a/quaint/src/macros.rs +++ b/quaint/src/macros.rs @@ -1,3 +1,5 @@ +use crate::{connector::ColumnType, Value}; + /// Convert given set of tuples into `Values`. /// /// ```rust @@ -173,7 +175,7 @@ macro_rules! expression { /// A test-generator to test types in the defined database. #[cfg(test)] macro_rules! test_type { - ($name:ident($db:ident, $sql_type:literal, $(($input:expr, $output:expr)),+ $(,)?)) => { + ($name:ident($db:ident, $sql_type:literal, $col_type:expr, $(($input:expr, $output:expr)),+ $(,)?)) => { paste::item! { #[test] fn [< test_type_ $name >] () -> crate::Result<()> { @@ -198,7 +200,9 @@ macro_rules! test_type { let select = Select::from_table(&table).column("value").order_by("id".descend()); let res = setup.conn().select(select).await?.into_single()?; + assert_eq!($col_type, res.types[0]); assert_eq!(Some(&output), res.at(0)); + assert_matching_value_and_column_type(&$col_type, res.at(0).unwrap()); )+ Result::<(), crate::error::Error>::Ok(()) @@ -209,7 +213,7 @@ macro_rules! test_type { } }; - ($name:ident($db:ident, $sql_type:literal, $($value:expr),+ $(,)?)) => { + ($name:ident($db:ident, $sql_type:literal, $col_type:expr, $($value:expr),+ $(,)?)) => { paste::item! { #[test] fn [< test_type_ $name >] () -> crate::Result<()> { @@ -232,7 +236,9 @@ macro_rules! test_type { let select = Select::from_table(&table).column("value").order_by("id".descend()); let res = setup.conn().select(select).await?.into_single()?; + assert_eq!($col_type, res.types[0]); assert_eq!(Some(&value), res.at(0)); + assert_matching_value_and_column_type(&$col_type, &value); )+ Result::<(), crate::error::Error>::Ok(()) @@ -243,3 +249,12 @@ macro_rules! test_type { } }; } + +#[allow(dead_code)] +pub(crate) fn assert_matching_value_and_column_type(col_type: &ColumnType, value: &Value) { + let inferred_column_type = ColumnType::from(&value.typed); + + if !inferred_column_type.is_unknown() { + assert_eq!(col_type, &inferred_column_type); + } +} diff --git a/quaint/src/prelude.rs b/quaint/src/prelude.rs index 1fe867ccd4cc..2984e49fdc94 100644 --- a/quaint/src/prelude.rs +++ b/quaint/src/prelude.rs @@ -1,7 +1,7 @@ //! A "prelude" for users of the `quaint` crate. pub use crate::ast::*; pub use crate::connector::{ - ConnectionInfo, DefaultTransaction, ExternalConnectionInfo, Queryable, ResultRow, ResultSet, SqlFamily, + ColumnType, ConnectionInfo, DefaultTransaction, ExternalConnectionInfo, Queryable, ResultRow, ResultSet, SqlFamily, TransactionCapable, }; pub use crate::{col, val, values}; diff --git a/quaint/src/tests/types/mssql.rs b/quaint/src/tests/types/mssql.rs index ac404dd8af38..bd3ce6555a69 100644 --- a/quaint/src/tests/types/mssql.rs +++ b/quaint/src/tests/types/mssql.rs @@ -2,11 +2,14 @@ mod bigdecimal; -use crate::tests::test_api::*; +use crate::macros::assert_matching_value_and_column_type; +use crate::{connector::ColumnType, tests::test_api::*}; +use std::str::FromStr; test_type!(nvarchar_limited( mssql, "NVARCHAR(10)", + ColumnType::Text, Value::null_text(), Value::text("foobar"), Value::text("ä½™"), @@ -15,6 +18,7 @@ test_type!(nvarchar_limited( test_type!(nvarchar_max( mssql, "NVARCHAR(max)", + ColumnType::Text, Value::null_text(), Value::text("foobar"), Value::text("ä½™"), @@ -24,6 +28,7 @@ test_type!(nvarchar_max( test_type!(ntext( mssql, "NTEXT", + ColumnType::Text, Value::null_text(), Value::text("foobar"), Value::text("ä½™"), @@ -32,6 +37,7 @@ test_type!(ntext( test_type!(varchar_limited( mssql, "VARCHAR(10)", + ColumnType::Text, Value::null_text(), Value::text("foobar"), )); @@ -39,15 +45,23 @@ test_type!(varchar_limited( test_type!(varchar_max( mssql, "VARCHAR(max)", + ColumnType::Text, Value::null_text(), Value::text("foobar"), )); -test_type!(text(mssql, "TEXT", Value::null_text(), Value::text("foobar"))); +test_type!(text( + mssql, + "TEXT", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); test_type!(tinyint( mssql, "tinyint", + ColumnType::Int32, Value::null_int32(), Value::int32(u8::MIN), Value::int32(u8::MAX), @@ -56,6 +70,7 @@ test_type!(tinyint( test_type!(smallint( mssql, "smallint", + ColumnType::Int32, Value::null_int32(), Value::int32(i16::MIN), Value::int32(i16::MAX), @@ -64,6 +79,7 @@ test_type!(smallint( test_type!(int( mssql, "int", + ColumnType::Int32, Value::null_int32(), Value::int32(i32::MIN), Value::int32(i32::MAX), @@ -72,27 +88,48 @@ test_type!(int( test_type!(bigint( mssql, "bigint", + ColumnType::Int64, Value::null_int64(), Value::int64(i64::MIN), Value::int64(i64::MAX), )); -test_type!(float_24(mssql, "float(24)", Value::null_float(), Value::float(1.23456),)); +test_type!(float_24( + mssql, + "float(24)", + ColumnType::Float, + Value::null_float(), + Value::float(1.23456), +)); -test_type!(real(mssql, "real", Value::null_float(), Value::float(1.123456))); +test_type!(real( + mssql, + "real", + ColumnType::Float, + Value::null_float(), + Value::float(1.123456) +)); test_type!(float_53( mssql, "float(53)", + ColumnType::Double, Value::null_double(), Value::double(1.1234567891) )); -test_type!(money(mssql, "money", Value::null_double(), Value::double(3.14))); +test_type!(money( + mssql, + "money", + ColumnType::Double, + Value::null_double(), + Value::double(3.14) +)); test_type!(smallmoney( mssql, "smallmoney", + ColumnType::Double, Value::null_double(), Value::double(3.14) )); @@ -100,6 +137,7 @@ test_type!(smallmoney( test_type!(boolean( mssql, "bit", + ColumnType::Boolean, Value::null_boolean(), Value::boolean(true), Value::boolean(false), @@ -108,6 +146,7 @@ test_type!(boolean( test_type!(binary( mssql, "binary(8)", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(b"DEADBEEF".to_vec()), )); @@ -115,6 +154,7 @@ test_type!(binary( test_type!(varbinary( mssql, "varbinary(8)", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(b"DEADBEEF".to_vec()), )); @@ -122,6 +162,7 @@ test_type!(varbinary( test_type!(image( mssql, "image", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(b"DEADBEEF".to_vec()), )); @@ -129,6 +170,7 @@ test_type!(image( test_type!(date( mssql, "date", + ColumnType::Date, Value::null_date(), Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()) )); @@ -136,26 +178,67 @@ test_type!(date( test_type!(time( mssql, "time", + ColumnType::Time, Value::null_time(), Value::time(chrono::NaiveTime::from_hms_opt(16, 20, 00).unwrap()) )); -test_type!(datetime2(mssql, "datetime2", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:00Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(datetime2( + mssql, + "datetime2", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:00Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); -test_type!(datetime(mssql, "datetime", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(datetime( + mssql, + "datetime", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); + +test_type!(datetimeoffset( + mssql, + "datetimeoffset", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); + +test_type!(smalldatetime( + mssql, + "smalldatetime", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:00Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); -test_type!(datetimeoffset(mssql, "datetimeoffset", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(uuid( + mssql, + "uniqueidentifier", + ColumnType::Uuid, + Value::null_uuid(), + Value::uuid(uuid::Uuid::from_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap()) +)); -test_type!(smalldatetime(mssql, "smalldatetime", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:00Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(xml( + mssql, + "xml", + ColumnType::Xml, + Value::null_xml(), + Value::xml("bar"), +)); diff --git a/quaint/src/tests/types/mssql/bigdecimal.rs b/quaint/src/tests/types/mssql/bigdecimal.rs index 8fe3761624d2..2a2ce02350d3 100644 --- a/quaint/src/tests/types/mssql/bigdecimal.rs +++ b/quaint/src/tests/types/mssql/bigdecimal.rs @@ -1,10 +1,12 @@ use super::*; -use crate::bigdecimal::BigDecimal; +use crate::macros::assert_matching_value_and_column_type; +use crate::{bigdecimal::BigDecimal, connector::ColumnType}; use std::str::FromStr; test_type!(numeric( mssql, "numeric(10,2)", + ColumnType::Numeric, Value::null_numeric(), Value::numeric(BigDecimal::from_str("3.14")?) )); @@ -12,6 +14,7 @@ test_type!(numeric( test_type!(numeric_10_2( mssql, "numeric(10,2)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950.123456")?), Value::numeric(BigDecimal::from_str("3950.12")?) @@ -21,6 +24,7 @@ test_type!(numeric_10_2( test_type!(numeric_35_6( mssql, "numeric(35, 6)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950")?), Value::numeric(BigDecimal::from_str("3950.000000")?) @@ -102,6 +106,7 @@ test_type!(numeric_35_6( test_type!(numeric_35_2( mssql, "numeric(35, 2)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950.123456")?), Value::numeric(BigDecimal::from_str("3950.12")?) @@ -115,18 +120,21 @@ test_type!(numeric_35_2( test_type!(numeric_4_0( mssql, "numeric(4, 0)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str("3950")?) )); test_type!(numeric_35_0( mssql, "numeric(35, 0)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str("79228162514264337593543950335")?), )); test_type!(numeric_35_1( mssql, "numeric(35, 1)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("79228162514264337593543950335")?), Value::numeric(BigDecimal::from_str("79228162514264337593543950335.0")?) @@ -141,12 +149,14 @@ test_type!(numeric_35_1( test_type!(numeric_38_6( mssql, "numeric(38, 6)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str("9343234567898765456789043634999.345678")?), )); test_type!(money( mssql, "money", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), (Value::numeric(BigDecimal::from_str("3.14")?), Value::double(3.14)) )); @@ -154,6 +164,7 @@ test_type!(money( test_type!(smallmoney( mssql, "smallmoney", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), (Value::numeric(BigDecimal::from_str("3.14")?), Value::double(3.14)) )); @@ -161,6 +172,7 @@ test_type!(smallmoney( test_type!(float_24( mssql, "float(24)", + ColumnType::Float, (Value::null_numeric(), Value::null_float()), ( Value::numeric(BigDecimal::from_str("1.123456")?), @@ -171,6 +183,7 @@ test_type!(float_24( test_type!(real( mssql, "real", + ColumnType::Float, (Value::null_numeric(), Value::null_float()), ( Value::numeric(BigDecimal::from_str("1.123456")?), @@ -181,6 +194,7 @@ test_type!(real( test_type!(float_53( mssql, "float(53)", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), ( Value::numeric(BigDecimal::from_str("1.123456789012345")?), diff --git a/quaint/src/tests/types/mysql.rs b/quaint/src/tests/types/mysql.rs index ade4e5d2a1f2..77444378735f 100644 --- a/quaint/src/tests/types/mysql.rs +++ b/quaint/src/tests/types/mysql.rs @@ -1,14 +1,15 @@ #![allow(clippy::approx_constant)] -use crate::tests::test_api::*; - use std::str::FromStr; use crate::bigdecimal::BigDecimal; +use crate::macros::assert_matching_value_and_column_type; +use crate::{connector::ColumnType, tests::test_api::*}; test_type!(tinyint( mysql, "tinyint(4)", + ColumnType::Int32, Value::null_int32(), Value::int32(i8::MIN), Value::int32(i8::MAX) @@ -17,6 +18,7 @@ test_type!(tinyint( test_type!(tinyint1( mysql, "tinyint(1)", + ColumnType::Int32, Value::int32(-1), Value::int32(1), Value::int32(0) @@ -25,6 +27,7 @@ test_type!(tinyint1( test_type!(tinyint_unsigned( mysql, "tinyint(4) unsigned", + ColumnType::Int32, Value::null_int32(), Value::int32(0), Value::int32(255) @@ -33,6 +36,7 @@ test_type!(tinyint_unsigned( test_type!(year( mysql, "year", + ColumnType::Int32, Value::null_int32(), Value::int32(1984), Value::int32(2049) @@ -41,6 +45,7 @@ test_type!(year( test_type!(smallint( mysql, "smallint", + ColumnType::Int32, Value::null_int32(), Value::int32(i16::MIN), Value::int32(i16::MAX) @@ -49,6 +54,7 @@ test_type!(smallint( test_type!(smallint_unsigned( mysql, "smallint unsigned", + ColumnType::Int32, Value::null_int32(), Value::int32(0), Value::int32(65535) @@ -57,6 +63,7 @@ test_type!(smallint_unsigned( test_type!(mediumint( mysql, "mediumint", + ColumnType::Int32, Value::null_int32(), Value::int32(-8388608), Value::int32(8388607) @@ -65,6 +72,7 @@ test_type!(mediumint( test_type!(mediumint_unsigned( mysql, "mediumint unsigned", + ColumnType::Int64, Value::null_int64(), Value::int64(0), Value::int64(16777215) @@ -73,6 +81,7 @@ test_type!(mediumint_unsigned( test_type!(int( mysql, "int", + ColumnType::Int32, Value::null_int32(), Value::int32(i32::MIN), Value::int32(i32::MAX) @@ -81,6 +90,7 @@ test_type!(int( test_type!(int_unsigned( mysql, "int unsigned", + ColumnType::Int64, Value::null_int64(), Value::int64(0), Value::int64(2173158296i64), @@ -90,6 +100,7 @@ test_type!(int_unsigned( test_type!(int_unsigned_not_null( mysql, "int unsigned not null", + ColumnType::Int64, Value::int64(0), Value::int64(2173158296i64), Value::int64(4294967295i64) @@ -98,6 +109,7 @@ test_type!(int_unsigned_not_null( test_type!(bigint( mysql, "bigint", + ColumnType::Int64, Value::null_int64(), Value::int64(i64::MIN), Value::int64(i64::MAX) @@ -106,6 +118,7 @@ test_type!(bigint( test_type!(decimal( mysql, "decimal(10,2)", + ColumnType::Numeric, Value::null_numeric(), Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()) )); @@ -114,6 +127,7 @@ test_type!(decimal( test_type!(decimal_65_6( mysql, "decimal(65, 6)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str( "93431006223456789876545678909876545678903434334567834369999.345678" )?), @@ -122,6 +136,7 @@ test_type!(decimal_65_6( test_type!(float_decimal( mysql, "float", + ColumnType::Float, (Value::null_numeric(), Value::null_float()), ( Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()), @@ -132,6 +147,7 @@ test_type!(float_decimal( test_type!(double_decimal( mysql, "double", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), ( Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()), @@ -142,6 +158,7 @@ test_type!(double_decimal( test_type!(bit1( mysql, "bit(1)", + ColumnType::Boolean, (Value::null_bytes(), Value::null_boolean()), (Value::int32(0), Value::boolean(false)), (Value::int32(1), Value::boolean(true)), @@ -150,28 +167,77 @@ test_type!(bit1( test_type!(bit64( mysql, "bit(64)", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(vec![0, 0, 0, 0, 0, 6, 107, 58]) )); -test_type!(char(mysql, "char(255)", Value::null_text(), Value::text("foobar"))); -test_type!(float(mysql, "float", Value::null_float(), Value::float(1.12345),)); -test_type!(double(mysql, "double", Value::null_double(), Value::double(1.12314124))); +test_type!(char( + mysql, + "char(255)", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); +test_type!(float( + mysql, + "float", + ColumnType::Float, + Value::null_float(), + Value::float(1.12345), +)); +test_type!(double( + mysql, + "double", + ColumnType::Double, + Value::null_double(), + Value::double(1.12314124) +)); test_type!(varchar( mysql, "varchar(255)", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); +test_type!(tinytext( + mysql, + "tinytext", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); +test_type!(text( + mysql, + "text", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); +test_type!(longtext( + mysql, + "longtext", + ColumnType::Text, Value::null_text(), Value::text("foobar") )); -test_type!(tinytext(mysql, "tinytext", Value::null_text(), Value::text("foobar"))); -test_type!(text(mysql, "text", Value::null_text(), Value::text("foobar"))); -test_type!(longtext(mysql, "longtext", Value::null_text(), Value::text("foobar"))); -test_type!(binary(mysql, "binary(5)", Value::bytes(vec![1, 2, 3, 0, 0]))); -test_type!(varbinary(mysql, "varbinary(255)", Value::bytes(vec![1, 2, 3]))); +test_type!(binary( + mysql, + "binary(5)", + ColumnType::Bytes, + Value::bytes(vec![1, 2, 3, 0, 0]) +)); +test_type!(varbinary( + mysql, + "varbinary(255)", + ColumnType::Bytes, + Value::bytes(vec![1, 2, 3]) +)); test_type!(mediumtext( mysql, "mediumtext", + ColumnType::Text, Value::null_text(), Value::text("foobar") )); @@ -179,6 +245,7 @@ test_type!(mediumtext( test_type!(tinyblob( mysql, "tinyblob", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(vec![1, 2, 3]) )); @@ -186,6 +253,7 @@ test_type!(tinyblob( test_type!(mediumblob( mysql, "mediumblob", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(vec![1, 2, 3]) )); @@ -193,15 +261,23 @@ test_type!(mediumblob( test_type!(longblob( mysql, "longblob", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(vec![1, 2, 3]) )); -test_type!(blob(mysql, "blob", Value::null_bytes(), Value::bytes(vec![1, 2, 3]))); +test_type!(blob( + mysql, + "blob", + ColumnType::Bytes, + Value::null_bytes(), + Value::bytes(vec![1, 2, 3]) +)); test_type!(enum( mysql, "enum('pollicle_dogs','jellicle_cats')", + ColumnType::Enum, Value::null_enum(), Value::enum_variant("jellicle_cats"), Value::enum_variant("pollicle_dogs") @@ -210,28 +286,50 @@ test_type!(enum( test_type!(json( mysql, "json", + ColumnType::Json, Value::null_json(), Value::json(serde_json::json!({"this": "is", "a": "json", "number": 2})) )); -test_type!(date(mysql, "date", Value::null_date(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-04-20T00:00:00Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(date( + mysql, + "date", + ColumnType::Date, + (Value::null_date(), Value::null_date()), + ( + Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()), + Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()) + ), + ( + Value::datetime( + chrono::DateTime::parse_from_rfc3339("2020-04-20T00:00:00Z") + .unwrap() + .with_timezone(&chrono::Utc) + ), + Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()) + ) +)); test_type!(time( mysql, "time", + ColumnType::Time, Value::null_time(), Value::time(chrono::NaiveTime::from_hms_opt(16, 20, 00).unwrap()) )); -test_type!(datetime(mysql, "datetime", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(datetime( + mysql, + "datetime", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); -test_type!(timestamp(mysql, "timestamp", { +test_type!(timestamp(mysql, "timestamp", ColumnType::DateTime, { let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); Value::datetime(dt.with_timezone(&chrono::Utc)) })); diff --git a/quaint/src/tests/types/postgres.rs b/quaint/src/tests/types/postgres.rs index d69a8dbb3424..99b3d67bf6e3 100644 --- a/quaint/src/tests/types/postgres.rs +++ b/quaint/src/tests/types/postgres.rs @@ -1,11 +1,13 @@ mod bigdecimal; -use crate::tests::test_api::*; +use crate::macros::assert_matching_value_and_column_type; +use crate::{connector::ColumnType, tests::test_api::*}; use std::str::FromStr; test_type!(boolean( postgresql, "boolean", + ColumnType::Boolean, Value::null_boolean(), Value::boolean(true), Value::boolean(false), @@ -14,6 +16,7 @@ test_type!(boolean( test_type!(boolean_array( postgresql, "boolean[]", + ColumnType::BooleanArray, Value::null_array(), Value::array(vec![ Value::boolean(true), @@ -26,6 +29,7 @@ test_type!(boolean_array( test_type!(int2( postgresql, "int2", + ColumnType::Int32, Value::null_int32(), Value::int32(i16::MIN), Value::int32(i16::MAX), @@ -34,6 +38,7 @@ test_type!(int2( test_type!(int2_with_int64( postgresql, "int2", + ColumnType::Int32, (Value::null_int64(), Value::null_int32()), (Value::int64(i16::MIN), Value::int32(i16::MIN)), (Value::int64(i16::MAX), Value::int32(i16::MAX)) @@ -42,6 +47,7 @@ test_type!(int2_with_int64( test_type!(int2_array( postgresql, "int2[]", + ColumnType::Int32Array, Value::null_array(), Value::array(vec![ Value::int32(1), @@ -54,6 +60,7 @@ test_type!(int2_array( test_type!(int2_array_with_i64( postgresql, "int2[]", + ColumnType::Int32Array, ( Value::array(vec![ Value::int64(i16::MIN), @@ -71,6 +78,7 @@ test_type!(int2_array_with_i64( test_type!(int4( postgresql, "int4", + ColumnType::Int32, Value::null_int32(), Value::int32(i32::MIN), Value::int32(i32::MAX), @@ -79,6 +87,7 @@ test_type!(int4( test_type!(int4_with_i64( postgresql, "int4", + ColumnType::Int32, (Value::null_int64(), Value::null_int32()), (Value::int64(i32::MIN), Value::int32(i32::MIN)), (Value::int64(i32::MAX), Value::int32(i32::MAX)) @@ -87,6 +96,7 @@ test_type!(int4_with_i64( test_type!(int4_array( postgresql, "int4[]", + ColumnType::Int32Array, Value::null_array(), Value::array(vec![ Value::int32(i32::MIN), @@ -98,6 +108,7 @@ test_type!(int4_array( test_type!(int4_array_with_i64( postgresql, "int4[]", + ColumnType::Int32Array, ( Value::array(vec![ Value::int64(i32::MIN), @@ -115,6 +126,7 @@ test_type!(int4_array_with_i64( test_type!(int8( postgresql, "int8", + ColumnType::Int64, Value::null_int64(), Value::int64(i64::MIN), Value::int64(i64::MAX), @@ -123,6 +135,7 @@ test_type!(int8( test_type!(int8_array( postgresql, "int8[]", + ColumnType::Int64Array, Value::null_array(), Value::array(vec![ Value::int64(1), @@ -132,11 +145,18 @@ test_type!(int8_array( ]), )); -test_type!(float4(postgresql, "float4", Value::null_float(), Value::float(1.234))); +test_type!(float4( + postgresql, + "float4", + ColumnType::Float, + Value::null_float(), + Value::float(1.234) +)); test_type!(float4_array( postgresql, "float4[]", + ColumnType::FloatArray, Value::null_array(), Value::array(vec![Value::float(1.1234), Value::float(4.321), Value::null_float()]) )); @@ -144,6 +164,7 @@ test_type!(float4_array( test_type!(float8( postgresql, "float8", + ColumnType::Double, Value::null_double(), Value::double(1.12345764), )); @@ -151,6 +172,7 @@ test_type!(float8( test_type!(float8_array( postgresql, "float8[]", + ColumnType::DoubleArray, Value::null_array(), Value::array(vec![Value::double(1.1234), Value::double(4.321), Value::null_double()]) )); @@ -160,6 +182,7 @@ test_type!(float8_array( test_type!(oid_with_i32( postgresql, "oid", + ColumnType::Int64, (Value::null_int32(), Value::null_int64()), (Value::int32(i32::MAX), Value::int64(i32::MAX)), (Value::int32(u32::MIN as i32), Value::int64(u32::MIN)), @@ -168,6 +191,7 @@ test_type!(oid_with_i32( test_type!(oid_with_i64( postgresql, "oid", + ColumnType::Int64, Value::null_int64(), Value::int64(u32::MAX), Value::int64(u32::MIN), @@ -176,6 +200,7 @@ test_type!(oid_with_i64( test_type!(oid_array( postgresql, "oid[]", + ColumnType::Int64Array, Value::null_array(), Value::array(vec![ Value::int64(1), @@ -188,6 +213,7 @@ test_type!(oid_array( test_type!(serial2( postgresql, "serial2", + ColumnType::Int32, Value::int32(i16::MIN), Value::int32(i16::MAX), )); @@ -195,6 +221,7 @@ test_type!(serial2( test_type!(serial4( postgresql, "serial4", + ColumnType::Int32, Value::int32(i32::MIN), Value::int32(i32::MAX), )); @@ -202,15 +229,23 @@ test_type!(serial4( test_type!(serial8( postgresql, "serial8", + ColumnType::Int64, Value::int64(i64::MIN), Value::int64(i64::MAX), )); -test_type!(char(postgresql, "char(6)", Value::null_text(), Value::text("foobar"))); +test_type!(char( + postgresql, + "char(6)", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); test_type!(char_array( postgresql, "char(6)[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![Value::text("foobar"), Value::text("omgwtf"), Value::null_text()]) )); @@ -218,6 +253,7 @@ test_type!(char_array( test_type!(varchar( postgresql, "varchar(255)", + ColumnType::Text, Value::null_text(), Value::text("foobar") )); @@ -225,24 +261,39 @@ test_type!(varchar( test_type!(varchar_array( postgresql, "varchar(255)[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![Value::text("foobar"), Value::text("omgwtf"), Value::null_text()]) )); -test_type!(text(postgresql, "text", Value::null_text(), Value::text("foobar"))); +test_type!(text( + postgresql, + "text", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); test_type!(text_array( postgresql, "text[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![Value::text("foobar"), Value::text("omgwtf"), Value::null_text()]) )); -test_type!(bit(postgresql, "bit(4)", Value::null_text(), Value::text("1001"))); +test_type!(bit( + postgresql, + "bit(4)", + ColumnType::Text, + Value::null_text(), + Value::text("1001") +)); test_type!(bit_array( postgresql, "bit(4)[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![Value::text("1001"), Value::text("0110"), Value::null_text()]) )); @@ -250,6 +301,7 @@ test_type!(bit_array( test_type!(varbit( postgresql, "varbit(20)", + ColumnType::Text, Value::null_text(), Value::text("001010101") )); @@ -257,6 +309,7 @@ test_type!(varbit( test_type!(varbit_array( postgresql, "varbit(20)[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![ Value::text("001010101"), @@ -265,11 +318,18 @@ test_type!(varbit_array( ]) )); -test_type!(inet(postgresql, "inet", Value::null_text(), Value::text("127.0.0.1"))); +test_type!(inet( + postgresql, + "inet", + ColumnType::Text, + Value::null_text(), + Value::text("127.0.0.1") +)); test_type!(inet_array( postgresql, "inet[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![ Value::text("127.0.0.1"), @@ -281,6 +341,7 @@ test_type!(inet_array( test_type!(json( postgresql, "json", + ColumnType::Json, Value::null_json(), Value::json(serde_json::json!({"foo": "bar"})) )); @@ -288,6 +349,7 @@ test_type!(json( test_type!(json_array( postgresql, "json[]", + ColumnType::JsonArray, Value::null_array(), Value::array(vec![ Value::json(serde_json::json!({"foo": "bar"})), @@ -299,6 +361,7 @@ test_type!(json_array( test_type!(jsonb( postgresql, "jsonb", + ColumnType::Json, Value::null_json(), Value::json(serde_json::json!({"foo": "bar"})) )); @@ -306,6 +369,7 @@ test_type!(jsonb( test_type!(jsonb_array( postgresql, "jsonb[]", + ColumnType::JsonArray, Value::null_array(), Value::array(vec![ Value::json(serde_json::json!({"foo": "bar"})), @@ -314,11 +378,18 @@ test_type!(jsonb_array( ]) )); -test_type!(xml(postgresql, "xml", Value::null_xml(), Value::xml("1",))); +test_type!(xml( + postgresql, + "xml", + ColumnType::Xml, + Value::null_xml(), + Value::xml("1",) +)); test_type!(xml_array( postgresql, "xml[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![ Value::text("1"), @@ -330,6 +401,7 @@ test_type!(xml_array( test_type!(uuid( postgresql, "uuid", + ColumnType::Uuid, Value::null_uuid(), Value::uuid(uuid::Uuid::from_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap()) )); @@ -337,6 +409,7 @@ test_type!(uuid( test_type!(uuid_array( postgresql, "uuid[]", + ColumnType::UuidArray, Value::null_array(), Value::array(vec![ Value::uuid(uuid::Uuid::from_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap()), @@ -347,6 +420,7 @@ test_type!(uuid_array( test_type!(date( postgresql, "date", + ColumnType::Date, Value::null_date(), Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()) )); @@ -354,6 +428,7 @@ test_type!(date( test_type!(date_array( postgresql, "date[]", + ColumnType::DateArray, Value::null_array(), Value::array(vec![ Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()), @@ -364,6 +439,7 @@ test_type!(date_array( test_type!(time( postgresql, "time", + ColumnType::Time, Value::null_time(), Value::time(chrono::NaiveTime::from_hms_opt(16, 20, 00).unwrap()) )); @@ -371,6 +447,7 @@ test_type!(time( test_type!(time_array( postgresql, "time[]", + ColumnType::TimeArray, Value::null_array(), Value::array(vec![ Value::time(chrono::NaiveTime::from_hms_opt(16, 20, 00).unwrap()), @@ -378,37 +455,62 @@ test_type!(time_array( ]) )); -test_type!(timestamp(postgresql, "timestamp", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(timestamp( + postgresql, + "timestamp", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); -test_type!(timestamp_array(postgresql, "timestamp[]", Value::null_array(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); +test_type!(timestamp_array( + postgresql, + "timestamp[]", + ColumnType::DateTimeArray, + Value::null_array(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::array(vec![ - Value::datetime(dt.with_timezone(&chrono::Utc)), - Value::null_datetime(), - ]) -})); + Value::array(vec![ + Value::datetime(dt.with_timezone(&chrono::Utc)), + Value::null_datetime(), + ]) + } +)); -test_type!(timestamptz(postgresql, "timestamptz", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(timestamptz( + postgresql, + "timestamptz", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); -test_type!(timestamptz_array(postgresql, "timestamptz[]", Value::null_array(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); +test_type!(timestamptz_array( + postgresql, + "timestamptz[]", + ColumnType::DateTimeArray, + Value::null_array(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::array(vec![ - Value::datetime(dt.with_timezone(&chrono::Utc)), - Value::null_datetime(), - ]) -})); + Value::array(vec![ + Value::datetime(dt.with_timezone(&chrono::Utc)), + Value::null_datetime(), + ]) + } +)); test_type!(bytea( postgresql, "bytea", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(b"DEADBEEF".to_vec()) )); @@ -416,6 +518,7 @@ test_type!(bytea( test_type!(bytea_array( postgresql, "bytea[]", + ColumnType::BytesArray, Value::null_array(), Value::array(vec![ Value::bytes(b"DEADBEEF".to_vec()), diff --git a/quaint/src/tests/types/postgres/bigdecimal.rs b/quaint/src/tests/types/postgres/bigdecimal.rs index 894b2c967629..1f8dcd663164 100644 --- a/quaint/src/tests/types/postgres/bigdecimal.rs +++ b/quaint/src/tests/types/postgres/bigdecimal.rs @@ -1,9 +1,11 @@ use super::*; use crate::bigdecimal::BigDecimal; +use crate::macros::assert_matching_value_and_column_type; test_type!(decimal( postgresql, "decimal(10,2)", + ColumnType::Numeric, Value::null_numeric(), Value::numeric(BigDecimal::from_str("3.14")?) )); @@ -11,6 +13,7 @@ test_type!(decimal( test_type!(decimal_10_2( postgresql, "decimal(10, 2)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950.123456")?), Value::numeric(BigDecimal::from_str("3950.12")?) @@ -20,6 +23,7 @@ test_type!(decimal_10_2( test_type!(decimal_35_6( postgresql, "decimal(35, 6)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950")?), Value::numeric(BigDecimal::from_str("3950.000000")?) @@ -101,6 +105,7 @@ test_type!(decimal_35_6( test_type!(decimal_35_2( postgresql, "decimal(35, 2)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950.123456")?), Value::numeric(BigDecimal::from_str("3950.12")?) @@ -114,12 +119,14 @@ test_type!(decimal_35_2( test_type!(decimal_4_0( postgresql, "decimal(4, 0)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str("3950")?) )); test_type!(decimal_65_30( postgresql, "decimal(65, 30)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("1.2")?), Value::numeric(BigDecimal::from_str("1.2000000000000000000000000000")?) @@ -133,6 +140,7 @@ test_type!(decimal_65_30( test_type!(decimal_65_34( postgresql, "decimal(65, 34)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3.1415926535897932384626433832795028")?), Value::numeric(BigDecimal::from_str("3.1415926535897932384626433832795028")?) @@ -150,12 +158,14 @@ test_type!(decimal_65_34( test_type!(decimal_35_0( postgresql, "decimal(35, 0)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str("79228162514264337593543950335")?), )); test_type!(decimal_35_1( postgresql, "decimal(35, 1)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("79228162514264337593543950335")?), Value::numeric(BigDecimal::from_str("79228162514264337593543950335.0")?) @@ -169,6 +179,7 @@ test_type!(decimal_35_1( test_type!(decimal_128_6( postgresql, "decimal(128, 6)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str( "93431006223456789876545678909876545678903434369343100622345678987654567890987654567890343436999999100622345678343699999910.345678" )?), @@ -177,6 +188,7 @@ test_type!(decimal_128_6( test_type!(decimal_array( postgresql, "decimal(10,2)[]", + ColumnType::NumericArray, Value::null_array(), Value::array(vec![BigDecimal::from_str("3.14")?, BigDecimal::from_str("5.12")?]) )); @@ -184,6 +196,7 @@ test_type!(decimal_array( test_type!(money( postgresql, "money", + ColumnType::Numeric, Value::null_numeric(), Value::numeric(BigDecimal::from_str("1.12")?) )); @@ -191,6 +204,7 @@ test_type!(money( test_type!(money_array( postgresql, "money[]", + ColumnType::NumericArray, Value::null_array(), Value::array(vec![BigDecimal::from_str("1.12")?, BigDecimal::from_str("1.12")?]) )); @@ -198,6 +212,7 @@ test_type!(money_array( test_type!(float4( postgresql, "float4", + ColumnType::Float, (Value::null_numeric(), Value::null_float()), ( Value::numeric(BigDecimal::from_str("1.123456")?), @@ -208,6 +223,7 @@ test_type!(float4( test_type!(float8( postgresql, "float8", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), ( Value::numeric(BigDecimal::from_str("1.123456")?), diff --git a/quaint/src/tests/types/sqlite.rs b/quaint/src/tests/types/sqlite.rs index c4950e748697..647f7217c83c 100644 --- a/quaint/src/tests/types/sqlite.rs +++ b/quaint/src/tests/types/sqlite.rs @@ -1,5 +1,7 @@ #![allow(clippy::approx_constant)] +use crate::connector::ColumnType; +use crate::macros::assert_matching_value_and_column_type; use crate::tests::test_api::sqlite_test_api; use crate::tests::test_api::TestApi; use crate::{ast::*, connector::Queryable}; @@ -9,6 +11,7 @@ use std::str::FromStr; test_type!(integer( sqlite, "INTEGER", + ColumnType::Int32, Value::null_int32(), Value::int32(i8::MIN), Value::int32(i8::MAX), @@ -21,17 +24,25 @@ test_type!(integer( test_type!(big_int( sqlite, "BIGINT", + ColumnType::Int64, Value::null_int64(), Value::int64(i64::MIN), Value::int64(i64::MAX), )); -test_type!(real(sqlite, "REAL", Value::null_double(), Value::double(1.12345))); +test_type!(real( + sqlite, + "REAL", + ColumnType::Double, + Value::null_double(), + Value::double(1.12345) +)); test_type!(float_decimal( sqlite, "FLOAT", - (Value::null_numeric(), Value::null_float()), + ColumnType::Double, + (Value::null_numeric(), Value::null_double()), ( Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()), Value::double(3.14) @@ -41,6 +52,7 @@ test_type!(float_decimal( test_type!(double_decimal( sqlite, "DOUBLE", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), ( Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()), @@ -48,27 +60,44 @@ test_type!(double_decimal( ) )); -test_type!(text(sqlite, "TEXT", Value::null_text(), Value::text("foobar huhuu"))); +test_type!(text( + sqlite, + "TEXT", + ColumnType::Text, + Value::null_text(), + Value::text("foobar huhuu") +)); test_type!(blob( sqlite, "BLOB", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(b"DEADBEEF".to_vec()) )); -test_type!(float(sqlite, "FLOAT", Value::null_float(), Value::double(1.23))); +test_type!(float( + sqlite, + "FLOAT", + ColumnType::Double, + (Value::null_float(), Value::null_double()), + (Value::null_double(), Value::null_double()), + (Value::float(1.23456), Value::double(1.23456)), + (Value::double(1.2312313213), Value::double(1.2312313213)) +)); test_type!(double( sqlite, "DOUBLE", + ColumnType::Double, Value::null_double(), - Value::double(1.2312313213) + Value::double(1.2312313213), )); test_type!(boolean( sqlite, "BOOLEAN", + ColumnType::Boolean, Value::null_boolean(), Value::boolean(true), Value::boolean(false) @@ -77,6 +106,7 @@ test_type!(boolean( test_type!(date( sqlite, "DATE", + ColumnType::Date, Value::null_date(), Value::date(chrono::NaiveDate::from_ymd_opt(1984, 1, 1).unwrap()) )); @@ -84,6 +114,7 @@ test_type!(date( test_type!(datetime( sqlite, "DATETIME", + ColumnType::DateTime, Value::null_datetime(), Value::datetime(chrono::DateTime::from_str("2020-07-29T09:23:44.458Z").unwrap()) )); @@ -104,6 +135,7 @@ async fn test_type_text_datetime_rfc3339(api: &mut dyn TestApi) -> crate::Result let res = api.conn().select(select).await?.into_single()?; assert_eq!(Some(&Value::datetime(dt)), res.at(0)); + assert_matching_value_and_column_type(&res.types[0], res.at(0).unwrap()); Ok(()) } @@ -125,7 +157,9 @@ async fn test_type_text_datetime_rfc2822(api: &mut dyn TestApi) -> crate::Result let select = Select::from_table(&table).column("value").order_by("id".descend()); let res = api.conn().select(select).await?.into_single()?; + assert_eq!(ColumnType::DateTime, res.types[0]); assert_eq!(Some(&Value::datetime(dt)), res.at(0)); + assert_matching_value_and_column_type(&res.types[0], res.at(0).unwrap()); Ok(()) } @@ -147,7 +181,9 @@ async fn test_type_text_datetime_custom(api: &mut dyn TestApi) -> crate::Result< let naive = chrono::NaiveDateTime::parse_from_str("2020-04-20 16:20:00", "%Y-%m-%d %H:%M:%S").unwrap(); let expected = chrono::DateTime::from_naive_utc_and_offset(naive, chrono::Utc); + assert_eq!(ColumnType::DateTime, res.types[0]); assert_eq!(Some(&Value::datetime(expected)), res.at(0)); + assert_matching_value_and_column_type(&res.types[0], res.at(0).unwrap()); Ok(()) } diff --git a/quaint/src/tests/types/utils.rs b/quaint/src/tests/types/utils.rs new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_6173.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_6173.rs index 3aa1e9b0e2d2..e5644fc42d42 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_6173.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_6173.rs @@ -19,13 +19,13 @@ mod query_raw { let res = run_query_json!( &runner, r#" - mutation { - queryRaw( - query: "BEGIN NOT ATOMIC\n INSERT INTO Test VALUES(FLOOR(RAND()*1000));\n SELECT * FROM Test;\n END", - parameters: "[]" - ) - } - "# + mutation { + queryRaw( + query: "BEGIN NOT ATOMIC\n INSERT INTO Test VALUES(FLOOR(RAND()*1000));\n SELECT * FROM Test;\n END", + parameters: "[]" + ) + } + "# ); // fmt_execute_raw cannot run this query, doing it directly instead insta::assert_json_snapshot!(res, @@ -53,4 +53,34 @@ mod query_raw { Ok(()) } + + #[connector_test(only(MySQL("mariadb")))] + async fn mysql_call_2(runner: Runner) -> TestResult<()> { + let res = run_query_json!( + &runner, + r#" + mutation { + queryRaw( + query: "BEGIN NOT ATOMIC\n INSERT INTO Test VALUES(FLOOR(RAND()*1000));\n SELECT * FROM Test WHERE 1=0;\n END", + parameters: "[]" + ) + } + "# + ); + + insta::assert_json_snapshot!(res, + @r###" + { + "data": { + "queryRaw": { + "columns": [], + "types": [], + "rows": [] + } + } + } + "###); + + Ok(()) + } } diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs index 6ebe42d0b089..21ca6e434079 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs @@ -176,7 +176,7 @@ mod transactional { let batch_results = runner.batch(queries, true, None).await?; insta::assert_snapshot!( batch_results.to_string(), - @r###"{"batchResult":[{"data":{"createOneModelB":{"id":1}}},{"data":{"executeRaw":1}},{"data":{"queryRaw":{"columns":["id"],"types":[],"rows":[]}}}]}"### + @r###"{"batchResult":[{"data":{"createOneModelB":{"id":1}}},{"data":{"executeRaw":1}},{"data":{"queryRaw":{"columns":["id"],"types":["int"],"rows":[]}}}]}"### ); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/scalar_list.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/scalar_list.rs index 20b1e853b64d..8cd4c59af670 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/scalar_list.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/scalar_list.rs @@ -254,12 +254,12 @@ mod scalar_list { "types": [ "int", "string-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array" + "int-array", + "bigint-array", + "double-array", + "bytes-array", + "bool-array", + "datetime-array" ], "rows": [ [ @@ -332,13 +332,13 @@ mod scalar_list { ], "types": [ "int", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array" + "string-array", + "int-array", + "bigint-array", + "double-array", + "bytes-array", + "bool-array", + "datetime-array" ], "rows": [ [ diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs index fa7b8d64692d..980392054bfd 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs @@ -440,6 +440,40 @@ mod typed_output { Ok(()) } + #[connector_test(schema(generic), only(Mysql))] + async fn unknown_type_mysql(runner: Runner) -> TestResult<()> { + insta::assert_snapshot!( + run_query!(&runner, fmt_query_raw(r#"SELECT POINT(1, 1);"#, vec![])), + @r###"{"data":{"queryRaw":{"columns":["POINT(1, 1)"],"types":["bytes"],"rows":[["AAAAAAEBAAAAAAAAAAAA8D8AAAAAAADwPw=="]]}}}"### + ); + + Ok(()) + } + + #[connector_test(schema(generic), only(Postgres))] + async fn unknown_type_pg(runner: Runner) -> TestResult<()> { + assert_error!( + &runner, + fmt_query_raw(r#"SELECT POINT(1, 1);"#, vec![]), + 2010, + "Failed to deserialize column of type 'point'" + ); + + Ok(()) + } + + #[connector_test(schema(generic), only(SqlServer))] + async fn unknown_type_mssql(runner: Runner) -> TestResult<()> { + assert_error!( + &runner, + fmt_query_raw(r#"SELECT geometry::Parse('POINT(3 4 7 2.5)');"#, vec![]), + 2010, + "not yet implemented for Udt" + ); + + Ok(()) + } + async fn create_row(runner: &Runner, data: &str) -> TestResult<()> { runner .query(format!("mutation {{ createOneTestModel(data: {data}) {{ id }} }}")) diff --git a/query-engine/connectors/sql-query-connector/src/ser_raw.rs b/query-engine/connectors/sql-query-connector/src/ser_raw.rs index 3c80b91d34ee..bbb23735704f 100644 --- a/query-engine/connectors/sql-query-connector/src/ser_raw.rs +++ b/query-engine/connectors/sql-query-connector/src/ser_raw.rs @@ -1,5 +1,5 @@ use quaint::{ - connector::{ResultRowRef, ResultSet}, + connector::{ColumnType, ResultRowRef, ResultSet}, Value, ValueType, }; use serde::{ser::*, Serialize, Serializer}; @@ -9,7 +9,7 @@ pub struct SerializedResultSet(pub ResultSet); #[derive(Debug, Serialize)] struct InnerSerializedResultSet<'a> { columns: SerializedColumns<'a>, - types: &'a SerializedTypes, + types: SerializedTypes<'a>, rows: SerializedRows<'a>, } @@ -22,7 +22,7 @@ impl serde::Serialize for SerializedResultSet { InnerSerializedResultSet { columns: SerializedColumns(this), - types: &SerializedTypes::new(this), + types: SerializedTypes(this), rows: SerializedRows(this), } .serialize(serializer) @@ -39,75 +39,65 @@ impl<'a> Serialize for SerializedColumns<'a> { { let this = &self.0; - if this.is_empty() { - return this.columns().serialize(serializer); - } - - let first_row = this.first().unwrap(); - - let mut seq = serializer.serialize_seq(Some(first_row.len()))?; + this.columns().serialize(serializer) + } +} - for (idx, _) in first_row.iter().enumerate() { - if let Some(column_name) = this.columns().get(idx) { - seq.serialize_element(column_name)?; - } else { - // `query_raw` does not return column names in `ResultSet` when a call to a stored procedure is done - // See https://github.com/prisma/prisma/issues/6173 - seq.serialize_element(&format!("f{idx}"))?; +#[derive(Debug)] +struct SerializedTypes<'a>(&'a ResultSet); + +impl<'a> SerializedTypes<'a> { + fn infer_unknown_column_types(&self) -> Vec { + let rows = self.0; + + let mut types = rows.types().to_owned(); + // Find all the unknown column types to avoid unnecessary iterations. + let unknown_indexes = rows + .types() + .iter() + .enumerate() + .filter_map(|(idx, ty)| match ty.is_unknown() { + true => Some(idx), + false => None, + }); + + for unknown_idx in unknown_indexes { + // While quaint already infers `ColumnType`s from the database, it can still have ColumnType::Unknown. + // In this case, we try to infer the types from the actual response data. + for row in self.0.iter() { + let current_type = types[unknown_idx]; + let inferred_type = ColumnType::from(&row[unknown_idx]); + + if current_type.is_unknown() && !inferred_type.is_unknown() { + types[unknown_idx] = inferred_type; + break; + } } } - seq.end() + if !self.0.is_empty() { + // Client doesn't know how to handle unknown types. + assert!(!types.contains(&ColumnType::Unknown)); + } + + types } } -#[derive(Debug, Serialize)] -#[serde(transparent)] -struct SerializedTypes(Vec); - -impl SerializedTypes { - fn new(rows: &ResultSet) -> Self { - if rows.is_empty() { - return Self(Vec::with_capacity(0)); - } +impl Serialize for SerializedTypes<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let types = self.infer_unknown_column_types(); - let row_len = rows.first().unwrap().len(); - let mut types = vec![SerializedValueType::Unknown; row_len]; - let mut types_found = 0; - - // This attempts to infer types based on `quaint::Value` present in the rows. - // We need to go through every row because because empty and null arrays don't encode their inner type. - // In the best case scenario, this loop stops at the first row. - // In the worst case scenario, it'll keep looping until it finds an array with a non-null value. - 'outer: for row in rows.iter() { - for (idx, value) in row.iter().enumerate() { - let current_type = types[idx]; - - if matches!( - current_type, - SerializedValueType::Unknown | SerializedValueType::UnknownArray - ) { - let inferred_type = SerializedValueType::infer_from(value); - - if inferred_type != SerializedValueType::Unknown && inferred_type != current_type { - types[idx] = inferred_type; - - if inferred_type != SerializedValueType::UnknownArray { - types_found += 1; - } - } - } + let mut seq = serializer.serialize_seq(Some(types.len()))?; - if types_found == row_len { - break 'outer; - } - } + for column_type in types { + seq.serialize_element(&column_type.to_string())?; } - // Client doesn't know how to handle unknown types. - assert!(!types.contains(&SerializedValueType::Unknown)); - - Self(types) + seq.end() } } @@ -200,361 +190,3 @@ impl<'a> Serialize for SerializedValue<'a> { } } } - -#[derive(Debug, Copy, Clone, PartialEq, Serialize)] -enum SerializedValueType { - #[serde(rename = "int")] - Int32, - #[serde(rename = "bigint")] - Int64, - #[serde(rename = "float")] - Float, - #[serde(rename = "double")] - Double, - #[serde(rename = "string")] - Text, - #[serde(rename = "enum")] - Enum, - #[serde(rename = "bytes")] - Bytes, - #[serde(rename = "bool")] - Boolean, - #[serde(rename = "char")] - Char, - #[serde(rename = "decimal")] - Numeric, - #[serde(rename = "json")] - Json, - #[serde(rename = "xml")] - Xml, - #[serde(rename = "uuid")] - Uuid, - #[serde(rename = "datetime")] - DateTime, - #[serde(rename = "date")] - Date, - #[serde(rename = "time")] - Time, - - #[serde(rename = "int-array")] - Int32Array, - #[serde(rename = "bigint-array")] - Int64Array, - #[serde(rename = "float-array")] - FloatArray, - #[serde(rename = "double-array")] - DoubleArray, - #[serde(rename = "string-array")] - TextArray, - #[serde(rename = "bytes-array")] - BytesArray, - #[serde(rename = "bool-array")] - BooleanArray, - #[serde(rename = "char-array")] - CharArray, - #[serde(rename = "decimal-array")] - NumericArray, - #[serde(rename = "json-array")] - JsonArray, - #[serde(rename = "xml-array")] - XmlArray, - #[serde(rename = "uuid-array")] - UuidArray, - #[serde(rename = "datetime-array")] - DateTimeArray, - #[serde(rename = "date-array")] - DateArray, - #[serde(rename = "time-array")] - TimeArray, - - #[serde(rename = "unknown-array")] - UnknownArray, - - #[serde(rename = "unknown")] - Unknown, -} - -impl SerializedValueType { - fn infer_from(value: &Value) -> SerializedValueType { - match &value.typed { - ValueType::Int32(_) => SerializedValueType::Int32, - ValueType::Int64(_) => SerializedValueType::Int64, - ValueType::Float(_) => SerializedValueType::Float, - ValueType::Double(_) => SerializedValueType::Double, - ValueType::Text(_) => SerializedValueType::Text, - ValueType::Enum(_, _) => SerializedValueType::Enum, - ValueType::EnumArray(_, _) => SerializedValueType::TextArray, - ValueType::Bytes(_) => SerializedValueType::Bytes, - ValueType::Boolean(_) => SerializedValueType::Boolean, - ValueType::Char(_) => SerializedValueType::Char, - ValueType::Numeric(_) => SerializedValueType::Numeric, - ValueType::Json(_) => SerializedValueType::Json, - ValueType::Xml(_) => SerializedValueType::Xml, - ValueType::Uuid(_) => SerializedValueType::Uuid, - ValueType::DateTime(_) => SerializedValueType::DateTime, - ValueType::Date(_) => SerializedValueType::Date, - ValueType::Time(_) => SerializedValueType::Time, - - ValueType::Array(Some(values)) => { - if values.is_empty() { - return SerializedValueType::UnknownArray; - } - - match &values[0].typed { - ValueType::Int32(_) => SerializedValueType::Int32Array, - ValueType::Int64(_) => SerializedValueType::Int64Array, - ValueType::Float(_) => SerializedValueType::FloatArray, - ValueType::Double(_) => SerializedValueType::DoubleArray, - ValueType::Text(_) => SerializedValueType::TextArray, - ValueType::Bytes(_) => SerializedValueType::BytesArray, - ValueType::Boolean(_) => SerializedValueType::BooleanArray, - ValueType::Char(_) => SerializedValueType::CharArray, - ValueType::Numeric(_) => SerializedValueType::NumericArray, - ValueType::Json(_) => SerializedValueType::JsonArray, - ValueType::Xml(_) => SerializedValueType::XmlArray, - ValueType::Uuid(_) => SerializedValueType::UuidArray, - ValueType::DateTime(_) => SerializedValueType::DateTimeArray, - ValueType::Date(_) => SerializedValueType::DateArray, - ValueType::Time(_) => SerializedValueType::TimeArray, - ValueType::Enum(_, _) => SerializedValueType::TextArray, - ValueType::Array(_) | ValueType::EnumArray(_, _) => { - unreachable!("Only PG supports scalar lists and tokio-postgres does not support 2d arrays") - } - } - } - ValueType::Array(None) => SerializedValueType::UnknownArray, - } - } -} - -#[cfg(test)] -mod tests { - use super::SerializedResultSet; - use bigdecimal::BigDecimal; - use chrono::{DateTime, Utc}; - use expect_test::expect; - use quaint::{ - ast::{EnumName, EnumVariant}, - connector::ResultSet, - Value, - }; - use std::str::FromStr; - - #[test] - fn serialize_result_set() { - let names = vec![ - "int32".to_string(), - "int64".to_string(), - "float".to_string(), - "double".to_string(), - "text".to_string(), - "enum".to_string(), - "bytes".to_string(), - "boolean".to_string(), - "char".to_string(), - "numeric".to_string(), - "json".to_string(), - "xml".to_string(), - "uuid".to_string(), - "datetime".to_string(), - "date".to_string(), - "time".to_string(), - "intArray".to_string(), - ]; - let rows = vec![vec![ - Value::int32(42), - Value::int64(42), - Value::float(42.523), - Value::double(42.523), - Value::text("heLlo"), - Value::enum_variant_with_name("Red", EnumName::new("Color", Option::::None)), - Value::bytes(b"hello".to_vec()), - Value::boolean(true), - Value::character('c'), - Value::numeric(BigDecimal::from_str("123456789.123456789").unwrap()), - Value::json(serde_json::json!({"hello": "world"})), - Value::xml("world"), - Value::uuid(uuid::Uuid::from_str("550e8400-e29b-41d4-a716-446655440000").unwrap()), - Value::datetime( - chrono::DateTime::parse_from_rfc3339("2021-01-01T02:00:00Z") - .map(DateTime::::from) - .unwrap(), - ), - Value::date(chrono::NaiveDate::from_ymd_opt(2021, 1, 1).unwrap()), - Value::time(chrono::NaiveTime::from_hms_opt(2, 0, 0).unwrap()), - Value::array(vec![Value::int32(42), Value::int32(42)]), - ]]; - let result_set = ResultSet::new(names, rows); - - let serialized = serde_json::to_string_pretty(&SerializedResultSet(result_set)).unwrap(); - - let expected = expect![[r#" - { - "columns": [ - "int32", - "int64", - "float", - "double", - "text", - "enum", - "bytes", - "boolean", - "char", - "numeric", - "json", - "xml", - "uuid", - "datetime", - "date", - "time", - "intArray" - ], - "types": [ - "int", - "bigint", - "float", - "double", - "string", - "enum", - "bytes", - "bool", - "char", - "decimal", - "json", - "xml", - "uuid", - "datetime", - "date", - "time", - "int-array" - ], - "rows": [ - [ - 42, - "42", - 42.523, - 42.523, - "heLlo", - "Red", - "aGVsbG8=", - true, - "c", - "123456789.123456789", - { - "hello": "world" - }, - "world", - "550e8400-e29b-41d4-a716-446655440000", - "2021-01-01T02:00:00+00:00", - "2021-01-01", - "02:00:00", - [ - 42, - 42 - ] - ] - ] - }"#]]; - - expected.assert_eq(&serialized); - } - - #[test] - fn serialize_empty_result_set() { - let names = vec!["hello".to_string()]; - let result_set = ResultSet::new(names, vec![]); - - let serialized = serde_json::to_string_pretty(&SerializedResultSet(result_set)).unwrap(); - - let expected = expect![[r#" - { - "columns": [ - "hello" - ], - "types": [], - "rows": [] - }"#]]; - - expected.assert_eq(&serialized) - } - - #[test] - fn serialize_arrays() { - let names = vec!["array".to_string()]; - let rows = vec![ - vec![Value::null_array()], - vec![Value::array(vec![Value::int32(42), Value::int64(42)])], - vec![Value::array(vec![Value::text("heLlo"), Value::null_text()])], - ]; - let result_set = ResultSet::new(names, rows); - - let serialized = serde_json::to_string_pretty(&SerializedResultSet(result_set)).unwrap(); - - let expected = expect![[r#" - { - "columns": [ - "array" - ], - "types": [ - "int-array" - ], - "rows": [ - [ - null - ], - [ - [ - 42, - "42" - ] - ], - [ - [ - "heLlo", - null - ] - ] - ] - }"#]]; - - expected.assert_eq(&serialized); - } - - #[test] - fn serialize_enum_array() { - let names = vec!["array".to_string()]; - let rows = vec![ - vec![Value::enum_array_with_name( - vec![EnumVariant::new("A"), EnumVariant::new("B")], - EnumName::new("Alphabet", Some("foo")), - )], - vec![Value::null_enum_array()], - ]; - let result_set = ResultSet::new(names, rows); - - let serialized = serde_json::to_string_pretty(&SerializedResultSet(result_set)).unwrap(); - - let expected = expect![[r#" - { - "columns": [ - "array" - ], - "types": [ - "string-array" - ], - "rows": [ - [ - [ - "A", - "B" - ] - ], - [ - null - ] - ] - }"#]]; - - expected.assert_eq(&serialized); - } -} diff --git a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs index 8fb07d6f6230..0282248701ad 100644 --- a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs +++ b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs @@ -5,7 +5,7 @@ pub use crate::types::{ColumnType, JSResultSet}; use quaint::bigdecimal::BigDecimal; use quaint::chrono::{DateTime, NaiveDate, NaiveTime, Utc}; use quaint::{ - connector::ResultSet as QuaintResultSet, + connector::{ColumnType as QuaintColumnType, ResultSet as QuaintResultSet}, error::{Error as QuaintError, ErrorKind}, Value as QuaintValue, }; @@ -22,6 +22,7 @@ impl TryFrom for QuaintResultSet { } = js_result_set; let mut quaint_rows = Vec::with_capacity(rows.len()); + let quaint_column_types = column_types.iter().map(QuaintColumnType::from).collect::>(); for row in rows { let mut quaint_row = Vec::with_capacity(column_types.len()); @@ -37,7 +38,7 @@ impl TryFrom for QuaintResultSet { } let last_insert_id = last_insert_id.and_then(|id| id.parse::().ok()); - let mut quaint_result_set = QuaintResultSet::new(column_names, quaint_rows); + let mut quaint_result_set = QuaintResultSet::new(column_names, quaint_column_types, quaint_rows); // Not a fan of this (extracting the `Some` value from an `Option` and pass it to a method that creates a new `Some` value), // but that's Quaint's ResultSet API and that's how the MySQL connector does it. diff --git a/query-engine/driver-adapters/src/types.rs b/query-engine/driver-adapters/src/types.rs index 1b4cbe531359..83c69fbf146d 100644 --- a/query-engine/driver-adapters/src/types.rs +++ b/query-engine/driver-adapters/src/types.rs @@ -6,7 +6,7 @@ use std::str::FromStr; #[cfg(not(target_arch = "wasm32"))] use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; -use quaint::connector::{ExternalConnectionInfo, SqlFamily}; +use quaint::connector::{ColumnType as QuaintColumnType, ExternalConnectionInfo, SqlFamily}; #[cfg(target_arch = "wasm32")] use tsify::Tsify; @@ -126,129 +126,150 @@ impl JSResultSet { } } -#[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi)] -#[cfg_attr(target_arch = "wasm32", derive(Clone, Copy, Deserialize_repr))] -#[repr(u8)] -#[derive(Debug)] -pub enum ColumnType { - // [PLANETSCALE_TYPE] (MYSQL_TYPE) -> [TypeScript example] +macro_rules! js_column_type { + ($($(#[$($attrss:tt)*])*$name:ident($val:expr) => $quaint_name:ident,)*) => { + #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi)] + #[cfg_attr(target_arch = "wasm32", derive(Clone, Copy, Deserialize_repr))] + #[repr(u8)] + #[derive(Debug)] + pub enum ColumnType { + $( + $(#[$($attrss)*])* + $name = $val, + )* + } + + impl From<&ColumnType> for QuaintColumnType { + fn from(value: &ColumnType) -> Self { + match value { + $(ColumnType::$name => QuaintColumnType::$quaint_name,)* + } + } + } + }; +} + +// JsColumnType(discriminant) => quaint::ColumnType +js_column_type! { + /// [PLANETSCALE_TYPE] (MYSQL_TYPE) -> [TypeScript example] /// The following PlanetScale type IDs are mapped into Int32: /// - INT8 (TINYINT) -> e.g. `127` /// - INT16 (SMALLINT) -> e.g. `32767` /// - INT24 (MEDIUMINT) -> e.g. `8388607` /// - INT32 (INT) -> e.g. `2147483647` - Int32 = 0, + Int32(0) => Int32, /// The following PlanetScale type IDs are mapped into Int64: /// - INT64 (BIGINT) -> e.g. `"9223372036854775807"` (String-encoded) - Int64 = 1, + Int64(1) => Int64, /// The following PlanetScale type IDs are mapped into Float: /// - FLOAT32 (FLOAT) -> e.g. `3.402823466` - Float = 2, + Float(2) => Float, /// The following PlanetScale type IDs are mapped into Double: /// - FLOAT64 (DOUBLE) -> e.g. `1.7976931348623157` - Double = 3, + Double(3) => Double, /// The following PlanetScale type IDs are mapped into Numeric: /// - DECIMAL (DECIMAL) -> e.g. `"99999999.99"` (String-encoded) - Numeric = 4, + Numeric(4) => Numeric, /// The following PlanetScale type IDs are mapped into Boolean: /// - BOOLEAN (BOOLEAN) -> e.g. `1` - Boolean = 5, + Boolean(5) => Boolean, - Character = 6, + + Character(6) => Char, /// The following PlanetScale type IDs are mapped into Text: /// - TEXT (TEXT) -> e.g. `"foo"` (String-encoded) /// - VARCHAR (VARCHAR) -> e.g. `"foo"` (String-encoded) - Text = 7, + Text(7) => Text, /// The following PlanetScale type IDs are mapped into Date: /// - DATE (DATE) -> e.g. `"2023-01-01"` (String-encoded, yyyy-MM-dd) - Date = 8, + Date(8) => Date, /// The following PlanetScale type IDs are mapped into Time: /// - TIME (TIME) -> e.g. `"23:59:59"` (String-encoded, HH:mm:ss) - Time = 9, + Time(9) => Time, + /// The following PlanetScale type IDs are mapped into DateTime: /// - DATETIME (DATETIME) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) /// - TIMESTAMP (TIMESTAMP) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) - DateTime = 10, + DateTime(10) => DateTime, /// The following PlanetScale type IDs are mapped into Json: /// - JSON (JSON) -> e.g. `"{\"key\": \"value\"}"` (String-encoded) - Json = 11, + Json(11) => Json, /// The following PlanetScale type IDs are mapped into Enum: /// - ENUM (ENUM) -> e.g. `"foo"` (String-encoded) - Enum = 12, + Enum(12) => Enum, + /// The following PlanetScale type IDs are mapped into Bytes: /// - BLOB (BLOB) -> e.g. `"\u0012"` (String-encoded) /// - VARBINARY (VARBINARY) -> e.g. `"\u0012"` (String-encoded) /// - BINARY (BINARY) -> e.g. `"\u0012"` (String-encoded) /// - GEOMETRY (GEOMETRY) -> e.g. `"\u0012"` (String-encoded) - Bytes = 13, + Bytes(13) => Bytes, + /// The following PlanetScale type IDs are mapped into Set: /// - SET (SET) -> e.g. `"foo,bar"` (String-encoded, comma-separated) /// This is currently unhandled, and will panic if encountered. - Set = 14, + Set(14) => Text, /// UUID from postgres-flavored driver adapters is mapped to this type. - Uuid = 15, + Uuid(15) => Uuid, - /* - * Scalar arrays - */ /// Int32 array (INT2_ARRAY and INT4_ARRAY in PostgreSQL) - Int32Array = 64, + Int32Array(64) => Int32Array, /// Int64 array (INT8_ARRAY in PostgreSQL) - Int64Array = 65, + Int64Array(65) => Int64Array, /// Float array (FLOAT4_ARRAY in PostgreSQL) - FloatArray = 66, + FloatArray(66) => FloatArray, /// Double array (FLOAT8_ARRAY in PostgreSQL) - DoubleArray = 67, + DoubleArray(67) => DoubleArray, /// Numeric array (NUMERIC_ARRAY, MONEY_ARRAY etc in PostgreSQL) - NumericArray = 68, + NumericArray(68) => NumericArray, /// Boolean array (BOOL_ARRAY in PostgreSQL) - BooleanArray = 69, + BooleanArray(69) => BooleanArray, /// Char array (CHAR_ARRAY in PostgreSQL) - CharacterArray = 70, + CharacterArray(70) => CharArray, /// Text array (TEXT_ARRAY in PostgreSQL) - TextArray = 71, + TextArray(71) => TextArray, /// Date array (DATE_ARRAY in PostgreSQL) - DateArray = 72, + DateArray(72) => DateArray, /// Time array (TIME_ARRAY in PostgreSQL) - TimeArray = 73, + TimeArray(73) => TimeArray, /// DateTime array (TIMESTAMP_ARRAY in PostgreSQL) - DateTimeArray = 74, + DateTimeArray(74) => DateTimeArray, /// Json array (JSON_ARRAY in PostgreSQL) - JsonArray = 75, + JsonArray(75) => JsonArray, - /// Enum array - EnumArray = 76, + /// Enum array (ENUM_ARRAY in PostgreSQL) + EnumArray(76) => TextArray, /// Bytes array (BYTEA_ARRAY in PostgreSQL) - BytesArray = 77, + BytesArray(77) => BytesArray, /// Uuid array (UUID_ARRAY in PostgreSQL) - UuidArray = 78, + UuidArray(78) => UuidArray, /* * Below there are custom types that don't have a 1:1 translation with a quaint::Value. @@ -259,7 +280,7 @@ pub enum ColumnType { /// /// It's used by some driver adapters, like libsql to return aggregation values like AVG, or /// COUNT, and it can be mapped to either Int64, or Double - UnknownNumber = 128, + UnknownNumber(128) => Unknown, } #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index fe9a66b449a5..e23d5927c555 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -75,7 +75,7 @@ all = [ graphql-protocol = ["query-core/graphql-protocol", "dep:graphql-parser"] [build-dependencies] -cfg_aliases = "0.2.0" +cfg_aliases = "0.2.1" [[bench]] name = "query_planning_bench" diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs index 959ed6de8632..d8ed39620911 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs @@ -2,7 +2,7 @@ pub(crate) use quaint::connector::rusqlite; -use quaint::connector::{GetRow, ToColumnNames}; +use quaint::connector::{ColumnType, GetRow, ToColumnNames}; use schema_connector::{ConnectorError, ConnectorResult}; use sql_schema_describer::{sqlite as describer, DescriberErrorKind, SqlSchema}; use std::sync::Mutex; @@ -56,6 +56,7 @@ impl Connection { let conn = self.0.lock().unwrap(); let mut stmt = conn.prepare_cached(sql).map_err(convert_error)?; + let column_types = stmt.columns().iter().map(ColumnType::from).collect::>(); let mut rows = stmt .query(rusqlite::params_from_iter(params.iter())) .map_err(convert_error)?; @@ -65,7 +66,11 @@ impl Connection { converted_rows.push(row.get_result_row().unwrap()); } - Ok(quaint::prelude::ResultSet::new(column_names, converted_rows)) + Ok(quaint::prelude::ResultSet::new( + column_names, + column_types, + converted_rows, + )) } } diff --git a/schema-engine/sql-schema-describer/src/sqlite.rs b/schema-engine/sql-schema-describer/src/sqlite.rs index 51f75a90343a..bd82c52fce0e 100644 --- a/schema-engine/sql-schema-describer/src/sqlite.rs +++ b/schema-engine/sql-schema-describer/src/sqlite.rs @@ -9,7 +9,7 @@ use either::Either; use indexmap::IndexMap; use quaint::{ ast::{Value, ValueType}, - connector::{GetRow, ToColumnNames}, + connector::{ColumnType as QuaintColumnType, GetRow, ToColumnNames}, prelude::ResultRow, }; use std::{any::type_name, borrow::Cow, collections::BTreeMap, convert::TryInto, fmt::Debug, path::Path}; @@ -33,6 +33,7 @@ impl Connection for std::sync::Mutex { ) -> quaint::Result { let conn = self.lock().unwrap(); let mut stmt = conn.prepare_cached(sql)?; + let column_types = stmt.columns().iter().map(QuaintColumnType::from).collect::>(); let mut rows = stmt.query(quaint::connector::rusqlite::params_from_iter(params.iter()))?; let column_names = rows.to_column_names(); let mut converted_rows = Vec::new(); @@ -40,7 +41,11 @@ impl Connection for std::sync::Mutex { converted_rows.push(row.get_result_row().unwrap()); } - Ok(quaint::prelude::ResultSet::new(column_names, converted_rows)) + Ok(quaint::prelude::ResultSet::new( + column_names, + column_types, + converted_rows, + )) } }