diff --git a/Cargo.lock b/Cargo.lock index 3097535d319d..1e9a74719224 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3497,6 +3497,7 @@ dependencies = [ "diagnostics", "either", "enumflags2", + "hex", "indoc 2.0.3", "itertools 0.12.0", "lsp-types", @@ -5006,7 +5007,6 @@ dependencies = [ "chrono", "cuid", "futures", - "hex", "itertools 0.12.0", "once_cell", "opentelemetry", diff --git a/psl/psl-core/Cargo.toml b/psl/psl-core/Cargo.toml index 97f4dd56d470..64343301c2c7 100644 --- a/psl/psl-core/Cargo.toml +++ b/psl/psl-core/Cargo.toml @@ -20,6 +20,7 @@ serde_json.workspace = true enumflags2 = "0.7" indoc.workspace = true either = "1.8.1" +hex = "0.4" # For the connector API. lsp-types = "0.91.1" diff --git a/psl/psl-core/src/builtin_connectors/cockroach_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/cockroach_datamodel_connector.rs index 69d1f0467b91..03b312ba3574 100644 --- a/psl/psl-core/src/builtin_connectors/cockroach_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/cockroach_datamodel_connector.rs @@ -60,9 +60,9 @@ const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Connector InsertReturning | UpdateReturning | RowIn | - LateralJoin | DeleteReturning | - SupportsFiltersOnRelationsWithoutJoins + SupportsFiltersOnRelationsWithoutJoins | + LateralJoin }); const SCALAR_TYPE_DEFAULTS: &[(ScalarType, CockroachType)] = &[ @@ -143,7 +143,7 @@ impl Connector for CockroachDatamodelConnector { } } - fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> NativeTypeInstance { + fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> Option { let native_type = SCALAR_TYPE_DEFAULTS .iter() .find(|(st, _)| st == scalar_type) @@ -151,7 +151,7 @@ impl Connector for CockroachDatamodelConnector { .ok_or_else(|| format!("Could not find scalar type {scalar_type:?} in SCALAR_TYPE_DEFAULTS")) .unwrap(); - NativeTypeInstance::new::(*native_type) + Some(NativeTypeInstance::new::(*native_type)) } fn native_type_is_default_for_scalar_type( @@ -320,17 +320,31 @@ impl Connector for CockroachDatamodelConnector { match native_type { Some(ct) => match ct { - CockroachType::Timestamptz(_) => super::utils::parse_timestamptz(str), - CockroachType::Timestamp(_) => super::utils::parse_timestamp(str), - CockroachType::Date => super::utils::parse_date(str), - CockroachType::Time(_) => super::utils::parse_time(str), - CockroachType::Timetz(_) => super::utils::parse_timetz(str), + CockroachType::Timestamptz(_) => super::utils::postgres::parse_timestamptz(str), + CockroachType::Timestamp(_) => super::utils::postgres::parse_timestamp(str), + CockroachType::Date => super::utils::common::parse_date(str), + CockroachType::Time(_) => super::utils::common::parse_time(str), + CockroachType::Timetz(_) => super::utils::postgres::parse_timetz(str), _ => unreachable!(), }, - None => self.parse_json_datetime( - str, - Some(self.default_native_type_for_scalar_type(&ScalarType::DateTime)), - ), + None => self.parse_json_datetime(str, self.default_native_type_for_scalar_type(&ScalarType::DateTime)), + } + } + + fn parse_json_bytes(&self, str: &str, nt: Option) -> prisma_value::PrismaValueResult> { + let native_type: Option<&CockroachType> = nt.as_ref().map(|nt| nt.downcast_ref()); + + match native_type { + Some(ct) => match ct { + CockroachType::Bytes => { + super::utils::postgres::parse_bytes(str).map_err(|_| prisma_value::ConversionFailure { + from: "hex".into(), + to: "bytes".into(), + }) + } + _ => unreachable!(), + }, + None => self.parse_json_bytes(str, self.default_native_type_for_scalar_type(&ScalarType::Bytes)), } } } diff --git a/psl/psl-core/src/builtin_connectors/mongodb.rs b/psl/psl-core/src/builtin_connectors/mongodb.rs index da111a28434f..814f3f60fd48 100644 --- a/psl/psl-core/src/builtin_connectors/mongodb.rs +++ b/psl/psl-core/src/builtin_connectors/mongodb.rs @@ -93,9 +93,10 @@ impl Connector for MongoDbDatamodelConnector { mongodb_types::CONSTRUCTORS } - fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> NativeTypeInstance { + fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> Option { let native_type = default_for(scalar_type); - NativeTypeInstance::new::(*native_type) + + Some(NativeTypeInstance::new::(*native_type)) } fn native_type_is_default_for_scalar_type( diff --git a/psl/psl-core/src/builtin_connectors/mongodb/mongodb_types.rs b/psl/psl-core/src/builtin_connectors/mongodb/mongodb_types.rs index 501f0bc5f268..7b37a03f6f7b 100644 --- a/psl/psl-core/src/builtin_connectors/mongodb/mongodb_types.rs +++ b/psl/psl-core/src/builtin_connectors/mongodb/mongodb_types.rs @@ -39,7 +39,7 @@ static DEFAULT_MAPPING: Lazy> = Lazy::new(|| { (ScalarType::Float, MongoDbType::Double), (ScalarType::Boolean, MongoDbType::Bool), (ScalarType::String, MongoDbType::String), - (ScalarType::DateTime, MongoDbType::Timestamp), + (ScalarType::DateTime, MongoDbType::Date), (ScalarType::Bytes, MongoDbType::BinData), (ScalarType::Json, MongoDbType::Json), ] diff --git a/psl/psl-core/src/builtin_connectors/mssql_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/mssql_datamodel_connector.rs index f0ef956a82bd..2146e2b95a1d 100644 --- a/psl/psl-core/src/builtin_connectors/mssql_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/mssql_datamodel_connector.rs @@ -141,14 +141,15 @@ impl Connector for MsSqlDatamodelConnector { } } - fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> NativeTypeInstance { + fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> Option { let nt = SCALAR_TYPE_DEFAULTS .iter() .find(|(st, _)| st == scalar_type) .map(|(_, native_type)| native_type) .ok_or_else(|| format!("Could not find scalar type {scalar_type:?} in SCALAR_TYPE_DEFAULTS")) .unwrap(); - NativeTypeInstance::new::(*nt) + + Some(NativeTypeInstance::new::(*nt)) } fn native_type_is_default_for_scalar_type( diff --git a/psl/psl-core/src/builtin_connectors/mysql_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/mysql_datamodel_connector.rs index 45b9adf27c35..a44a2639e430 100644 --- a/psl/psl-core/src/builtin_connectors/mysql_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/mysql_datamodel_connector.rs @@ -1,7 +1,9 @@ mod native_types; mod validations; +use chrono::FixedOffset; pub use native_types::MySqlType; +use prisma_value::{decode_bytes, PrismaValueResult}; use super::completions; use crate::{ @@ -64,7 +66,8 @@ const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Connector SupportsTxIsolationRepeatableRead | SupportsTxIsolationSerializable | RowIn | - SupportsFiltersOnRelationsWithoutJoins + SupportsFiltersOnRelationsWithoutJoins | + CorrelatedSubqueries }); const CONSTRAINT_SCOPES: &[ConstraintScope] = &[ConstraintScope::GlobalForeignKey, ConstraintScope::ModelKeyIndex]; @@ -160,7 +163,7 @@ impl Connector for MySqlDatamodelConnector { } } - fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> NativeTypeInstance { + fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> Option { let native_type = SCALAR_TYPE_DEFAULTS .iter() .find(|(st, _)| st == scalar_type) @@ -168,7 +171,7 @@ impl Connector for MySqlDatamodelConnector { .ok_or_else(|| format!("Could not find scalar type {scalar_type:?} in SCALAR_TYPE_DEFAULTS")) .unwrap(); - NativeTypeInstance::new::(*native_type) + Some(NativeTypeInstance::new::(*native_type)) } fn native_type_is_default_for_scalar_type( @@ -289,4 +292,28 @@ impl Connector for MySqlDatamodelConnector { fn flavour(&self) -> Flavour { Flavour::Mysql } + + fn parse_json_datetime( + &self, + str: &str, + nt: Option, + ) -> chrono::ParseResult> { + let native_type: Option<&MySqlType> = nt.as_ref().map(|nt| nt.downcast_ref()); + + match native_type { + Some(pt) => match pt { + Date => super::utils::common::parse_date(str), + Time(_) => super::utils::common::parse_time(str), + DateTime(_) => super::utils::mysql::parse_datetime(str), + Timestamp(_) => super::utils::mysql::parse_timestamp(str), + _ => unreachable!(), + }, + None => self.parse_json_datetime(str, self.default_native_type_for_scalar_type(&ScalarType::DateTime)), + } + } + + // On MySQL, bytes are encoded as base64 in the database directly. + fn parse_json_bytes(&self, str: &str, _nt: Option) -> PrismaValueResult> { + decode_bytes(str) + } } diff --git a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs index 45963c1c58dd..4da85fa89861 100644 --- a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs @@ -68,9 +68,9 @@ const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Connector UpdateReturning | RowIn | DistinctOn | - LateralJoin | DeleteReturning | - SupportsFiltersOnRelationsWithoutJoins + SupportsFiltersOnRelationsWithoutJoins | + LateralJoin }); pub struct PostgresDatamodelConnector; @@ -331,7 +331,7 @@ impl Connector for PostgresDatamodelConnector { } } - fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> NativeTypeInstance { + fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> Option { let native_type = SCALAR_TYPE_DEFAULTS .iter() .find(|(st, _)| st == scalar_type) @@ -339,7 +339,7 @@ impl Connector for PostgresDatamodelConnector { .ok_or_else(|| format!("Could not find scalar type {scalar_type:?} in SCALAR_TYPE_DEFAULTS")) .unwrap(); - NativeTypeInstance::new::(*native_type) + Some(NativeTypeInstance::new::(*native_type)) } fn native_type_is_default_for_scalar_type( @@ -580,17 +580,31 @@ impl Connector for PostgresDatamodelConnector { match native_type { Some(pt) => match pt { - Timestamptz(_) => super::utils::parse_timestamptz(str), - Timestamp(_) => super::utils::parse_timestamp(str), - Date => super::utils::parse_date(str), - Time(_) => super::utils::parse_time(str), - Timetz(_) => super::utils::parse_timetz(str), + Timestamptz(_) => super::utils::postgres::parse_timestamptz(str), + Timestamp(_) => super::utils::postgres::parse_timestamp(str), + Date => super::utils::common::parse_date(str), + Time(_) => super::utils::common::parse_time(str), + Timetz(_) => super::utils::postgres::parse_timetz(str), _ => unreachable!(), }, - None => self.parse_json_datetime( - str, - Some(self.default_native_type_for_scalar_type(&ScalarType::DateTime)), - ), + None => self.parse_json_datetime(str, self.default_native_type_for_scalar_type(&ScalarType::DateTime)), + } + } + + fn parse_json_bytes(&self, str: &str, nt: Option) -> prisma_value::PrismaValueResult> { + let native_type: Option<&PostgresType> = nt.as_ref().map(|nt| nt.downcast_ref()); + + match native_type { + Some(ct) => match ct { + PostgresType::ByteA => { + super::utils::postgres::parse_bytes(str).map_err(|_| prisma_value::ConversionFailure { + from: "hex".into(), + to: "bytes".into(), + }) + } + _ => unreachable!(), + }, + None => self.parse_json_bytes(str, self.default_native_type_for_scalar_type(&ScalarType::Bytes)), } } } diff --git a/psl/psl-core/src/builtin_connectors/sqlite_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/sqlite_datamodel_connector.rs index b8e6c69b8fb0..8c0756b97cc0 100644 --- a/psl/psl-core/src/builtin_connectors/sqlite_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/sqlite_datamodel_connector.rs @@ -66,8 +66,8 @@ impl Connector for SqliteDatamodelConnector { unreachable!("No native types on Sqlite"); } - fn default_native_type_for_scalar_type(&self, _scalar_type: &ScalarType) -> NativeTypeInstance { - NativeTypeInstance::new(()) + fn default_native_type_for_scalar_type(&self, _scalar_type: &ScalarType) -> Option { + None } fn native_type_is_default_for_scalar_type( diff --git a/psl/psl-core/src/builtin_connectors/utils.rs b/psl/psl-core/src/builtin_connectors/utils.rs index 3ef9f55cd80a..a8d5618f5d23 100644 --- a/psl/psl-core/src/builtin_connectors/utils.rs +++ b/psl/psl-core/src/builtin_connectors/utils.rs @@ -1,37 +1,63 @@ -use chrono::*; +pub(crate) mod common { + use chrono::*; -pub(crate) fn parse_date(str: &str) -> Result, chrono::ParseError> { - chrono::NaiveDate::parse_from_str(str, "%Y-%m-%d") - .map(|date| DateTime::::from_utc(date.and_hms_opt(0, 0, 0).unwrap(), Utc)) - .map(DateTime::::from) -} + pub(crate) fn parse_date(str: &str) -> Result, chrono::ParseError> { + chrono::NaiveDate::parse_from_str(str, "%Y-%m-%d") + .map(|date| DateTime::::from_utc(date.and_hms_opt(0, 0, 0).unwrap(), Utc)) + .map(DateTime::::from) + } -pub(crate) fn parse_timestamptz(str: &str) -> Result, chrono::ParseError> { - DateTime::parse_from_rfc3339(str) -} + pub(crate) fn parse_time(str: &str) -> Result, chrono::ParseError> { + chrono::NaiveTime::parse_from_str(str, "%H:%M:%S%.f") + .map(|time| { + let base_date = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); -pub(crate) fn parse_timestamp(str: &str) -> Result, chrono::ParseError> { - NaiveDateTime::parse_from_str(str, "%Y-%m-%dT%H:%M:%S%.f") - .map(|dt| DateTime::from_utc(dt, Utc)) - .or_else(|_| DateTime::parse_from_rfc3339(str).map(DateTime::::from)) - .map(DateTime::::from) + DateTime::::from_utc(base_date.and_time(time), Utc) + }) + .map(DateTime::::from) + } + + pub(crate) fn parse_timestamp(str: &str, fmt: &str) -> Result, chrono::ParseError> { + NaiveDateTime::parse_from_str(str, fmt) + .map(|dt| DateTime::from_utc(dt, Utc)) + .or_else(|_| DateTime::parse_from_rfc3339(str).map(DateTime::::from)) + .map(DateTime::::from) + } } -pub(crate) fn parse_time(str: &str) -> Result, chrono::ParseError> { - chrono::NaiveTime::parse_from_str(str, "%H:%M:%S%.f") - .map(|time| { - let base_date = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); +pub(crate) mod postgres { + use chrono::*; + + pub(crate) fn parse_timestamptz(str: &str) -> Result, chrono::ParseError> { + DateTime::parse_from_rfc3339(str) + } - DateTime::::from_utc(base_date.and_time(time), Utc) - }) - .map(DateTime::::from) + pub(crate) fn parse_timestamp(str: &str) -> Result, chrono::ParseError> { + super::common::parse_timestamp(str, "%Y-%m-%dT%H:%M:%S%.f") + } + + pub(crate) fn parse_timetz(str: &str) -> Result, chrono::ParseError> { + // We currently don't support time with timezone. + // We strip the timezone information and parse it as a time. + // This is inline with what Quaint does already. + let time_without_tz = str.split('+').next().unwrap(); + + super::common::parse_time(time_without_tz) + } + + pub(crate) fn parse_bytes(str: &str) -> Result, hex::FromHexError> { + hex::decode(&str[2..]) + } } -pub(crate) fn parse_timetz(str: &str) -> Result, chrono::ParseError> { - // We currently don't support time with timezone. - // We strip the timezone information and parse it as a time. - // This is inline with what Quaint does already. - let time_without_tz = str.split('+').next().unwrap(); +pub(crate) mod mysql { + use chrono::*; + + pub(crate) fn parse_datetime(str: &str) -> Result, chrono::ParseError> { + super::common::parse_timestamp(str, "%Y-%m-%d %H:%M:%S%.f") + } - parse_time(time_without_tz) + pub(crate) fn parse_timestamp(str: &str) -> Result, chrono::ParseError> { + parse_datetime(str) + } } diff --git a/psl/psl-core/src/datamodel_connector.rs b/psl/psl-core/src/datamodel_connector.rs index 41e5708dc2b0..107abd24710f 100644 --- a/psl/psl-core/src/datamodel_connector.rs +++ b/psl/psl-core/src/datamodel_connector.rs @@ -195,7 +195,7 @@ pub trait Connector: Send + Sync { /// On each connector, each built-in Prisma scalar type (`Boolean`, /// `String`, `Float`, etc.) has a corresponding native type. - fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> NativeTypeInstance; + fn default_native_type_for_scalar_type(&self, scalar_type: &ScalarType) -> Option; /// Same mapping as `default_native_type_for_scalar_type()`, but in the opposite direction. fn native_type_is_default_for_scalar_type( @@ -321,6 +321,14 @@ pub trait Connector: Send + Sync { ) -> chrono::ParseResult> { unreachable!("This method is only implemented on connectors with lateral join support.") } + + fn parse_json_bytes( + &self, + _str: &str, + _nt: Option, + ) -> prisma_value::PrismaValueResult> { + unreachable!("This method is only implemented on connectors with lateral join support.") + } } #[derive(Copy, Clone, Debug, PartialEq)] diff --git a/psl/psl-core/src/datamodel_connector/capabilities.rs b/psl/psl-core/src/datamodel_connector/capabilities.rs index 52a524397b7a..9b3fe025d8ca 100644 --- a/psl/psl-core/src/datamodel_connector/capabilities.rs +++ b/psl/psl-core/src/datamodel_connector/capabilities.rs @@ -103,11 +103,12 @@ capabilities!( NativeUpsert, InsertReturning, UpdateReturning, - RowIn, // Connector supports (a, b) IN (c, d) expression. - DistinctOn, // Connector supports DB-level distinct (e.g. postgres) - LateralJoin, - DeleteReturning, // Connector supports deleting records and returning them in one operation. + RowIn, // Connector supports (a, b) IN (c, d) expression. + DistinctOn, // Connector supports DB-level distinct (e.g. postgres) + DeleteReturning, // Connector supports deleting records and returning them in one operation. SupportsFiltersOnRelationsWithoutJoins, // Connector supports rendering filters on relation fields without joins. + LateralJoin, // Connector supports lateral joins to resolve relations. + CorrelatedSubqueries, // Connector supports correlated subqueries to resolve relations. ); /// Contains all capabilities that the connector is able to serve. diff --git a/psl/psl-core/src/datamodel_connector/empty_connector.rs b/psl/psl-core/src/datamodel_connector/empty_connector.rs index 7ac7879c08f4..7c917ea9d08a 100644 --- a/psl/psl-core/src/datamodel_connector/empty_connector.rs +++ b/psl/psl-core/src/datamodel_connector/empty_connector.rs @@ -41,8 +41,8 @@ impl Connector for EmptyDatamodelConnector { ScalarType::String } - fn default_native_type_for_scalar_type(&self, _scalar_type: &ScalarType) -> NativeTypeInstance { - unreachable!() + fn default_native_type_for_scalar_type(&self, _scalar_type: &ScalarType) -> Option { + None } fn native_type_is_default_for_scalar_type( diff --git a/quaint/src/visitor/mysql.rs b/quaint/src/visitor/mysql.rs index 690079d65ea5..a406000cd7c0 100644 --- a/quaint/src/visitor/mysql.rs +++ b/quaint/src/visitor/mysql.rs @@ -90,6 +90,36 @@ impl<'a> Mysql<'a> { Ok(()) } + + fn visit_json_build_obj_expr(&mut self, expr: Expression<'a>) -> crate::Result<()> { + match expr.kind() { + // Convert bytes data to base64 + ExpressionKind::Column(col) => match (col.type_family.as_ref(), col.native_type.as_deref()) { + ( + Some(TypeFamily::Text(_)), + Some("LONGBLOB") | Some("BLOB") | Some("MEDIUMBLOB") | Some("SMALLBLOB") | Some("TINYBLOB") + | Some("VARBINARY") | Some("BINARY") | Some("BIT"), + ) => { + self.write("to_base64")?; + self.surround_with("(", ")", |s| s.visit_expression(expr))?; + + Ok(()) + } + // Convert floats to string to avoid losing precision + (_, Some("FLOAT")) => { + self.write("CONVERT")?; + self.surround_with("(", ")", |s| { + s.visit_expression(expr)?; + s.write(", ")?; + s.write("CHAR") + })?; + Ok(()) + } + _ => self.visit_expression(expr), + }, + _ => self.visit_expression(expr), + } + } } impl<'a> Visitor<'a> for Mysql<'a> { @@ -562,14 +592,34 @@ impl<'a> Visitor<'a> for Mysql<'a> { Ok(()) } - #[cfg(feature = "postgresql")] - fn visit_json_array_agg(&mut self, _array_agg: JsonArrayAgg<'a>) -> visitor::Result { - unimplemented!("JSON_ARRAYAGG is not yet supported on MySQL") + #[cfg(feature = "mysql")] + fn visit_json_array_agg(&mut self, array_agg: JsonArrayAgg<'a>) -> visitor::Result { + self.write("JSON_ARRAYAGG")?; + self.surround_with("(", ")", |s| s.visit_expression(*array_agg.expr))?; + + Ok(()) } - #[cfg(feature = "postgresql")] - fn visit_json_build_object(&mut self, _build_obj: JsonBuildObject<'a>) -> visitor::Result { - unimplemented!("JSON_OBJECT is not yet supported on MySQL") + #[cfg(feature = "mysql")] + fn visit_json_build_object(&mut self, build_obj: JsonBuildObject<'a>) -> visitor::Result { + let len = build_obj.exprs.len(); + + self.write("JSON_OBJECT")?; + self.surround_with("(", ")", |s| { + for (i, (name, expr)) in build_obj.exprs.into_iter().enumerate() { + s.visit_raw_value(Value::text(name))?; + s.write(", ")?; + s.visit_json_build_obj_expr(expr)?; + + if i < (len - 1) { + s.write(", ")?; + } + } + + Ok(()) + })?; + + Ok(()) } fn visit_ordering(&mut self, ordering: Ordering<'a>) -> visitor::Result { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/relation_load_strategy.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/relation_load_strategy.rs index 71f1256b0426..a44a2c2c9308 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/relation_load_strategy.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/relation_load_strategy.rs @@ -86,7 +86,9 @@ mod relation_load_strategy { async fn assert_used_lateral_join(runner: &mut Runner, expected: bool) { let logs = runner.get_logs().await; - let actual = logs.iter().any(|l| l.contains("LEFT JOIN LATERAL")); + let actual = logs + .iter() + .any(|l| l.contains("LEFT JOIN LATERAL") || (l.contains("JSON_ARRAYAGG") && l.contains("JSON_OBJECT"))); assert_eq!( actual, expected, @@ -123,8 +125,21 @@ mod relation_load_strategy { macro_rules! relation_load_strategy_tests_pair { ($name:ident, $query:expr, $result:literal) => { - relation_load_strategy_test!($name, join, $query, $result, only(Postgres, CockroachDb)); - relation_load_strategy_test!($name, query, $query, $result); + relation_load_strategy_test!( + $name, + join, + $query, + $result, + only(Postgres, CockroachDb, Mysql(8)) + ); + // TODO: Remove Mysql & Vitess exclusions once we are able to have version speficic preview features. + relation_load_strategy_test!( + $name, + query, + $query, + $result, + exclude(Mysql("5.6", "5.7", "mariadb")) + ); }; } diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/native/mod.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/native/mod.rs index 70faf80832c5..933f4eb62968 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/native/mod.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/native/mod.rs @@ -1 +1,2 @@ +mod mysql; mod postgres; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/native/mysql.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/native/mysql.rs new file mode 100644 index 000000000000..d958b956e7d9 --- /dev/null +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/native/mysql.rs @@ -0,0 +1,284 @@ +use indoc::indoc; +use query_engine_tests::*; + +#[test_suite(only(Mysql("8")))] +mod datetime { + fn schema_date() -> String { + let schema = indoc! { + r#"model Parent { + #id(id, Int, @id) + + childId Int? @unique + child Child? @relation(fields: [childId], references: [id]) + } + + model Child { + #id(id, Int, @id) + date DateTime @test.Date + date_2 DateTime @test.Date + time DateTime @test.Time(3) + time_2 DateTime @test.Time(3) + ts DateTime @test.Timestamp(3) + ts_2 DateTime @test.Timestamp(3) + dt DateTime @test.DateTime(3) + dt_2 DateTime @test.DateTime(3) + year Int @test.Year + + parent Parent? + }"# + }; + + schema.to_owned() + } + + #[connector_test(schema(schema_date))] + async fn dt_native(runner: Runner) -> TestResult<()> { + create_row( + &runner, + r#"{ + id: 1, + child: { create: { + id: 1, + date: "2016-09-24T00:00:00.000Z" + date_2: "2016-09-24T00:00:00.000+03:00" + time: "1111-11-11T13:02:20.321Z" + time_2: "1111-11-11T13:02:20.321+03:00" + ts: "2016-09-24T14:01:30.213Z" + ts_2: "2016-09-24T14:01:30.213+03:00" + dt: "2016-09-24T14:01:30.213Z" + dt_2: "2016-09-24T14:01:30.213+03:00", + year: 2023 + }} + }"#, + ) + .await?; + + insta::assert_snapshot!( + run_query!(runner, r#"{ findManyParent { id child { date date_2 time time_2 ts ts_2 dt dt_2 year } } }"#), + @r###"{"data":{"findManyParent":[{"id":1,"child":{"date":"2016-09-24T00:00:00.000Z","date_2":"2016-09-23T00:00:00.000Z","time":"1970-01-01T13:02:20.321Z","time_2":"1970-01-01T10:02:20.321Z","ts":"2016-09-24T14:01:30.213Z","ts_2":"2016-09-24T11:01:30.213Z","dt":"2016-09-24T14:01:30.213Z","dt_2":"2016-09-24T11:01:30.213Z","year":2023}}]}}"### + ); + + Ok(()) + } + + async fn create_row(runner: &Runner, data: &str) -> TestResult<()> { + runner + .query(format!("mutation {{ createOneParent(data: {}) {{ id }} }}", data)) + .await? + .assert_success(); + Ok(()) + } +} + +#[test_suite(only(Mysql("8")))] +mod decimal { + fn schema_decimal() -> String { + let schema = indoc! { + r#" + model Parent { + #id(id, Int, @id) + + childId Int? @unique + child Child? @relation(fields: [childId], references: [id]) + } + + model Child { + #id(id, Int, @id) + + float Float @test.Float + dfloat Float @test.Double + decFloat Decimal @test.Decimal(2, 1) + + parent Parent? + }"# + }; + + schema.to_owned() + } + + // "Postgres native decimal types" should "work" + #[connector_test(schema(schema_decimal))] + async fn native_decimal_types(runner: Runner) -> TestResult<()> { + create_row( + &runner, + r#"{ + id: 1, + child: { create: { + id: 1, + float: 1.1 + dfloat: 2.2 + decFloat: 3.1234 + }} + }"#, + ) + .await?; + + insta::assert_snapshot!( + run_query!(&runner, r#"{ findManyParent { id child { float dfloat decFloat } } }"#), + @r###"{"data":{"findManyParent":[{"id":1,"child":{"float":1.1,"dfloat":2.2,"decFloat":"3.1"}}]}}"### + ); + + Ok(()) + } + + async fn create_row(runner: &Runner, data: &str) -> TestResult<()> { + runner + .query(format!("mutation {{ createOneParent(data: {}) {{ id }} }}", data)) + .await? + .assert_success(); + Ok(()) + } +} + +#[test_suite(only(Mysql("8")))] +mod string { + fn schema_string() -> String { + let schema = indoc! { + r#" + model Parent { + #id(id, Int, @id) + + childId Int? @unique + child Child? @relation(fields: [childId], references: [id]) + } + + model Child { + #id(id, Int, @id) + char String @test.Char(10) + vChar String @test.VarChar(11) + tText String @test.TinyText + text String @test.Text + mText String @test.MediumText + ltext String @test.LongText + + parent Parent? + }"# + }; + + schema.to_owned() + } + + // "Mysql native string types" should "work" + #[connector_test(schema(schema_string))] + async fn native_string(runner: Runner) -> TestResult<()> { + create_row( + &runner, + r#"{ + id: 1, + child: { create: { + id: 1, + char: "1234567890" + vChar: "12345678910" + tText: "tiny text" + text: "text" + mText: "medium text" + ltext: "long text" + }} + }"#, + ) + .await?; + + insta::assert_snapshot!( + run_query!(&runner, r#"{ findManyParent { + id + child { + char + vChar + tText + text + mText + ltext + } + }}"#), + @r###"{"data":{"findManyParent":[{"id":1,"child":{"char":"1234567890","vChar":"12345678910","tText":"tiny text","text":"text","mText":"medium text","ltext":"long text"}}]}}"### + ); + + Ok(()) + } + + async fn create_row(runner: &Runner, data: &str) -> TestResult<()> { + runner + .query(format!("mutation {{ createOneParent(data: {}) {{ id }} }}", data)) + .await? + .assert_success(); + Ok(()) + } +} + +#[test_suite(only(MySql("8")))] +mod bytes { + fn schema_bytes() -> String { + let schema = indoc! { + r#" + model Parent { + #id(id, Int, @id) + + childId Int? @unique + child Child? @relation(fields: [childId], references: [id]) + } + + model Child { + #id(id, Int, @id) + bit Bytes @test.Bit(8) + bin Bytes @test.Binary(4) + vBin Bytes @test.VarBinary(5) + blob Bytes @test.Blob + tBlob Bytes @test.TinyBlob + mBlob Bytes @test.MediumBlob + lBlob Bytes @test.LongBlob + + parent Parent? + }"# + }; + + schema.to_owned() + } + + // "Mysql native bytes types" should "work" + #[connector_test(schema(schema_bytes))] + async fn native_bytes(runner: Runner) -> TestResult<()> { + create_row( + &runner, + r#"{ + id: 1, + child: { create: { + id: 1, + bit: "dA==" + bin: "dGVzdA==" + vBin: "dGVzdA==" + blob: "dGVzdA==" + tBlob: "dGVzdA==" + mBlob: "dGVzdA==" + lBlob: "dGVzdA==" + }} + }"#, + ) + .await?; + + insta::assert_snapshot!( + run_query!(&runner, r#"{ findManyParent { + id + child { + bit + bin + vBin + blob + tBlob + mBlob + lBlob + } + }}"#), + @r###"{"data":{"findManyParent":[{"id":1,"child":{"bit":"dA==","bin":"dGVzdA==","vBin":"dGVzdA==","blob":"dGVzdA==","tBlob":"dGVzdA==","mBlob":"dGVzdA==","lBlob":"dGVzdA=="}}]}}"### + ); + + Ok(()) + } + + async fn create_row(runner: &Runner, data: &str) -> TestResult<()> { + runner + .query(format!("mutation {{ createOneParent(data: {}) {{ id }} }}", data)) + .await? + .assert_success(); + Ok(()) + } +} diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/nested_atomic_number_ops.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/nested_atomic_number_ops.rs index 32c2092a9339..c325fccb6d64 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/nested_atomic_number_ops.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/nested_atomic_number_ops.rs @@ -324,7 +324,7 @@ mod atomic_number_ops { } // "A nested updateOne mutation" should "correctly apply all number operations for Int" - #[connector_test(schema(schema_3), exclude(MongoDb, Postgres("pg.js", "neon.js")))] + #[connector_test(schema(schema_3), exclude(MongoDb))] async fn nested_update_float_ops(runner: Runner) -> TestResult<()> { create_test_model(&runner, 1, None, None).await?; create_test_model(&runner, 2, None, Some("5.5")).await?; diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs index 60d4cd6801fa..ac07d9b71546 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/connector_tag/mod.rs @@ -326,6 +326,11 @@ impl ConnectorVersion { | Self::Sqlite(Some(SqliteVersion::LibsqlJsWasm)) ) } + + /// Returns `true` if the connector version is [`MySql`]. + pub(crate) fn is_mysql(&self) -> bool { + matches!(self, Self::MySql(..)) + } } impl fmt::Display for ConnectorVersion { @@ -378,12 +383,12 @@ pub(crate) fn should_run( let exclusions = exclude .iter() - .filter_map(|c| ConnectorVersion::try_from(*c).ok()) + .map(|c| ConnectorVersion::try_from(*c).unwrap()) .collect::>(); let inclusions = only .iter() - .filter_map(|c| ConnectorVersion::try_from(*c).ok()) + .map(|c| ConnectorVersion::try_from(*c).unwrap()) .collect::>(); for exclusion in exclusions.iter() { diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/datamodel_rendering/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/datamodel_rendering/mod.rs index 7295972f9812..5390ee975d89 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/datamodel_rendering/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/datamodel_rendering/mod.rs @@ -4,11 +4,13 @@ mod sql_renderer; pub use mongodb_renderer::*; pub use sql_renderer::*; -use crate::{connection_string, templating, DatamodelFragment, IdFragment, M2mFragment, CONFIG}; +use crate::{ + connection_string, templating, ConnectorVersion, DatamodelFragment, IdFragment, M2mFragment, MySqlVersion, CONFIG, +}; use indoc::indoc; use itertools::Itertools; use once_cell::sync::Lazy; -use psl::ALL_PREVIEW_FEATURES; +use psl::{PreviewFeature, ALL_PREVIEW_FEATURES}; use regex::Regex; /// Test configuration, loaded once at runtime. @@ -37,7 +39,7 @@ pub fn render_test_datamodel( isolation_level: Option<&'static str>, ) -> String { let (tag, version) = CONFIG.test_connector().unwrap(); - let preview_features = render_preview_features(excluded_features); + let preview_features = render_preview_features(excluded_features, &version); let is_multi_schema = !db_schemas.is_empty(); @@ -89,8 +91,13 @@ fn process_template(template: String, renderer: Box) -> S }) } -fn render_preview_features(excluded_features: &[&str]) -> String { - let excluded_features: Vec<_> = excluded_features.iter().map(|f| format!(r#""{f}""#)).collect(); +fn render_preview_features(excluded_features: &[&str], version: &ConnectorVersion) -> String { + let mut excluded_features: Vec<_> = excluded_features.iter().map(|f| format!(r#""{f}""#)).collect(); + + // TODO: Remove this once we are able to have version speficic preview features. + if version.is_mysql() && !matches!(version, ConnectorVersion::MySql(Some(MySqlVersion::V8))) { + excluded_features.push(format!(r#""{}""#, PreviewFeature::RelationJoins)); + } ALL_PREVIEW_FEATURES .active_features() diff --git a/query-engine/connectors/mongodb-query-connector/src/value.rs b/query-engine/connectors/mongodb-query-connector/src/value.rs index b0d4946f23cf..a9f79e941f7e 100644 --- a/query-engine/connectors/mongodb-query-connector/src/value.rs +++ b/query-engine/connectors/mongodb-query-connector/src/value.rs @@ -152,7 +152,7 @@ impl IntoBson for (&MongoDbType, PrismaValue) { // Double (MongoDbType::Double, PrismaValue::Int(i)) => Bson::Double(i as f64), - (MongoDbType::Double, PrismaValue::Float(f)) => Bson::Double(f.to_f64().convert(expl::MONGO_DOUBLE)?), + (MongoDbType::Double, PrismaValue::Float(f)) => bigdecimal_to_bson_double(f)?, (MongoDbType::Double, PrismaValue::BigInt(b)) => Bson::Double(b.to_f64().convert(expl::MONGO_DOUBLE)?), // Int @@ -189,6 +189,15 @@ impl IntoBson for (&MongoDbType, PrismaValue) { increment: 0, }), + (MongoDbType::Json, PrismaValue::Json(json)) => { + let val: Value = serde_json::from_str(&json)?; + + Bson::try_from(val).map_err(|_| MongoError::ConversionError { + from: "Stringified JSON".to_owned(), + to: "Mongo BSON (extJSON)".to_owned(), + })? + } + // Unhandled conversions (mdb_type, p_val) => { return Err(MongoError::ConversionError { @@ -231,15 +240,7 @@ impl IntoBson for (&TypeIdentifier, PrismaValue) { (TypeIdentifier::BigInt, PrismaValue::Float(dec)) => Bson::Int64(dec.to_i64().convert(expl::MONGO_I64)?), // Float - (TypeIdentifier::Float, PrismaValue::Float(dec)) => { - // We don't have native support for float numbers (yet) - // so we need to do this, see https://docs.rs/bigdecimal/latest/bigdecimal/index.html - let dec_str = dec.to_string(); - let f64_val = dec_str.parse::().ok(); - let converted = f64_val.convert(expl::MONGO_DOUBLE)?; - - Bson::Double(converted) - } + (TypeIdentifier::Float, PrismaValue::Float(dec)) => bigdecimal_to_bson_double(dec)?, (TypeIdentifier::Float, PrismaValue::Int(i)) => Bson::Double(i.to_f64().convert(expl::MONGO_DOUBLE)?), (TypeIdentifier::Float, PrismaValue::BigInt(i)) => Bson::Double(i.to_f64().convert(expl::MONGO_DOUBLE)?), @@ -474,6 +475,14 @@ fn format_opt(opt: Option) -> String { } } +fn bigdecimal_to_bson_double(dec: BigDecimal) -> crate::Result { + let dec_str = dec.to_string(); + let f64_val = dec_str.parse::().ok(); + let converted = f64_val.convert(expl::MONGO_DOUBLE)?; + + Ok(Bson::Double(converted)) +} + /// Explanation constants for conversion errors. mod expl { #![allow(dead_code)] diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index b467fcd277bf..354ec5bc0887 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -27,7 +27,6 @@ uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } -hex = "0.4" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] quaint.workspace = true diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs b/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs index 153354c518c8..0d697b97dc65 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs @@ -18,12 +18,29 @@ pub(crate) fn coerce_record_with_json_relation( ) -> crate::Result<()> { for (val_idx, kind) in indexes { let val = record.values.get_mut(*val_idx).unwrap(); - // TODO(perf): Find ways to avoid serializing and deserializing multiple times. - let json_val: serde_json::Value = serde_json::from_str(val.as_json().unwrap()).unwrap(); - *val = match kind { - IndexedSelection::Relation(rs) => coerce_json_relation_to_pv(json_val, rs)?, - IndexedSelection::Virtual(name) => coerce_json_virtual_field_to_pv(name, json_val)?, + match kind { + IndexedSelection::Relation(rs) => { + match val { + PrismaValue::Null if rs.field.is_list() => { + *val = PrismaValue::List(vec![]); + } + PrismaValue::Null if rs.field.is_optional() => { + continue; + } + val => { + // TODO(perf): Find ways to avoid serializing and deserializing multiple times. + let json_val: serde_json::Value = serde_json::from_str(val.as_json().unwrap()).unwrap(); + + *val = coerce_json_relation_to_pv(json_val, rs)?; + } + } + } + IndexedSelection::Virtual(name) => { + let json_val: serde_json::Value = serde_json::from_str(val.as_json().unwrap()).unwrap(); + + *val = coerce_json_virtual_field_to_pv(name, json_val)? + } }; } @@ -34,6 +51,9 @@ fn coerce_json_relation_to_pv(value: serde_json::Value, rs: &RelationSelection) let relations = rs.relations().collect_vec(); match value { + // Some versions of MySQL return null when offsetting by more than the number of rows available. + serde_json::Value::Null if rs.field.is_list() => Ok(PrismaValue::List(vec![])), + serde_json::Value::Null if rs.field.is_optional() => Ok(PrismaValue::Null), // one-to-many serde_json::Value::Array(values) if rs.field.is_list() => { let iter = values.into_iter().filter_map(|value| { @@ -57,21 +77,6 @@ fn coerce_json_relation_to_pv(value: serde_json::Value, rs: &RelationSelection) Ok(PrismaValue::List(iter.collect::>>()?)) } - // to-one - serde_json::Value::Array(values) => { - let coerced = values - .into_iter() - .next() - .map(|value| coerce_json_relation_to_pv(value, rs)); - - // TODO(HACK): We probably want to update the sql builder instead to not aggregate to-one relations as array - // If the array is empty, it means there's no relations, so we coerce it to - if let Some(val) = coerced { - val - } else { - Ok(PrismaValue::Null) - } - } serde_json::Value::Object(obj) => { let mut map: Vec<(String, PrismaValue)> = Vec::with_capacity(obj.len()); let related_model = rs.field.related_model(); @@ -133,6 +138,17 @@ pub(crate) fn coerce_json_scalar_to_pv(value: serde_json::Value, sf: &ScalarFiel Ok(PrismaValue::Float(bd)) } + TypeIdentifier::Boolean => { + let err = + || build_conversion_error(sf, &format!("Number({n})"), &format!("{:?}", sf.type_identifier())); + let i = n.as_i64().ok_or_else(err)?; + + match i { + 0 => Ok(PrismaValue::Boolean(false)), + 1 => Ok(PrismaValue::Boolean(true)), + _ => Err(err()), + } + } _ => Err(build_conversion_error( sf, &format!("Number({n})"), @@ -154,7 +170,7 @@ pub(crate) fn coerce_json_scalar_to_pv(value: serde_json::Value, sf: &ScalarFiel Ok(PrismaValue::DateTime(res)) } - TypeIdentifier::Decimal => { + TypeIdentifier::Decimal | TypeIdentifier::Float => { let res = parse_decimal(&s).map_err(|err| { build_conversion_error_with_reason( sf, @@ -175,8 +191,7 @@ pub(crate) fn coerce_json_scalar_to_pv(value: serde_json::Value, sf: &ScalarFiel ) })?)), TypeIdentifier::Bytes => { - // We skip the first two characters because there's the \x prefix. - let bytes = hex::decode(&s[2..]).map_err(|err| { + let bytes = sf.parse_json_bytes(&s).map_err(|err| { build_conversion_error_with_reason( sf, &format!("String({s})"), diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs index 07287fee303c..f9041c6dcd78 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs @@ -42,7 +42,7 @@ pub(crate) async fn get_single_record_joins( &field_names, ); - let query = query_builder::select::SelectBuilder::default().build( + let query = query_builder::select::SelectBuilder::build( QueryArguments::from((model.clone(), filter.clone())), selected_fields, ctx, @@ -155,7 +155,7 @@ pub(crate) async fn get_many_records_joins( _ => (), }; - let query = query_builder::select::SelectBuilder::default().build(query_arguments.clone(), selected_fields, ctx); + let query = query_builder::select::SelectBuilder::build(query_arguments.clone(), selected_fields, ctx); for item in conn.filter(query.into(), meta.as_slice(), ctx).await?.into_iter() { let mut record = Record::from(item); diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/select/lateral.rs b/query-engine/connectors/sql-query-connector/src/query_builder/select/lateral.rs new file mode 100644 index 000000000000..5b86bfaa581b --- /dev/null +++ b/query-engine/connectors/sql-query-connector/src/query_builder/select/lateral.rs @@ -0,0 +1,206 @@ +use super::*; + +use crate::{ + context::Context, + filter::alias::{Alias, AliasMode}, + model_extensions::AsColumn, +}; + +use quaint::ast::*; +use query_structure::*; + +/// Select builder for joined queries. Relations are resolved using LATERAL JOINs. +#[derive(Debug, Default)] +pub(crate) struct LateralJoinSelectBuilder { + alias: Alias, +} + +impl JoinSelectBuilder for LateralJoinSelectBuilder { + /// Builds a SELECT statement for the given query arguments and selected fields. + /// + /// ```sql + /// SELECT + /// id, + /// name + /// FROM "User" + /// LEFT JOIN LATERAL ( + /// SELECT JSON_OBJECT(<...>) FROM "Post" WHERE "Post"."authorId" = "User"."id + /// ) as "post" ON TRUE + /// ``` + fn build(&mut self, args: QueryArguments, selected_fields: &FieldSelection, ctx: &Context<'_>) -> Select<'static> { + let (select, parent_alias) = self.build_default_select(&args, ctx); + let select = self.with_selection(select, selected_fields, parent_alias, ctx); + let select = self.with_relations(select, selected_fields.relations(), parent_alias, ctx); + + self.with_virtual_selections(select, selected_fields.virtuals(), parent_alias, ctx) + } + + fn build_selection<'a>( + &mut self, + select: Select<'a>, + field: &SelectedField, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + match field { + SelectedField::Scalar(sf) => select.column( + sf.as_column(ctx) + .table(parent_alias.to_table_string()) + .set_is_selected(true), + ), + SelectedField::Relation(rs) => { + let table_name = match rs.field.relation().is_many_to_many() { + true => m2m_join_alias_name(&rs.field), + false => join_alias_name(&rs.field), + }; + + select.value(Column::from((table_name, JSON_AGG_IDENT)).alias(rs.field.name().to_owned())) + } + _ => select, + } + } + + fn add_to_one_relation<'a>( + &mut self, + select: Select<'a>, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + let (subselect, child_alias) = + self.build_to_one_select(rs, parent_alias, |expr: Expression<'_>| expr.alias(JSON_AGG_IDENT), ctx); + let subselect = self.with_relations(subselect, rs.relations(), child_alias, ctx); + let subselect = self.with_virtual_selections(subselect, rs.virtuals(), child_alias, ctx); + + let join_table = Table::from(subselect).alias(join_alias_name(&rs.field)); + // LEFT JOIN LATERAL ( ) AS ON TRUE + select.left_join(join_table.on(ConditionTree::single(true.raw())).lateral()) + } + + fn add_to_many_relation<'a>( + &mut self, + select: Select<'a>, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + let join_table_alias = join_alias_name(&rs.field); + let join_table = Table::from(self.build_to_many_select(rs, parent_alias, ctx)).alias(join_table_alias); + + // LEFT JOIN LATERAL ( ) AS ON TRUE + select.left_join(join_table.on(ConditionTree::single(true.raw())).lateral()) + } + + fn add_many_to_many_relation<'a>( + &mut self, + select: Select<'a>, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + let m2m_join = self.build_m2m_join(rs, parent_alias, ctx); + + select.left_join(m2m_join) + } + + fn add_virtual_selection<'a>( + &mut self, + select: Select<'a>, + vs: &VirtualSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + let relation_count_select = self.build_virtual_select(vs, parent_alias, ctx); + let table = Table::from(relation_count_select).alias(relation_count_alias_name(vs.relation_field())); + + select.left_join_lateral(table.on(ConditionTree::single(true.raw()))) + } + + fn build_json_obj_fn( + &mut self, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Expression<'static> { + let build_obj_params = rs + .selections + .iter() + .filter_map(|field| match field { + SelectedField::Scalar(sf) => Some(( + Cow::from(sf.db_name().to_owned()), + Expression::from(sf.as_column(ctx).table(parent_alias.to_table_string())), + )), + SelectedField::Relation(rs) => { + let table_name = match rs.field.relation().is_many_to_many() { + true => m2m_join_alias_name(&rs.field), + false => join_alias_name(&rs.field), + }; + + Some(( + Cow::from(rs.field.name().to_owned()), + Expression::from(Column::from((table_name, JSON_AGG_IDENT))), + )) + } + _ => None, + }) + .chain(self.build_json_obj_virtual_selection(rs.virtuals(), parent_alias, ctx)) + .collect(); + + json_build_object(build_obj_params).into() + } + + fn build_virtual_expr( + &mut self, + vs: &VirtualSelection, + _parent_alias: Alias, + _ctx: &Context<'_>, + ) -> Expression<'static> { + let rf = vs.relation_field(); + + coalesce([ + Expression::from(Column::from((relation_count_alias_name(rf), vs.db_alias()))), + Expression::from(0.raw()), + ]) + .into() + } + + fn next_alias(&mut self) -> Alias { + self.alias = self.alias.inc(AliasMode::Table); + self.alias + } +} + +impl LateralJoinSelectBuilder { + fn build_m2m_join<'a>(&mut self, rs: &RelationSelection, parent_alias: Alias, ctx: &Context<'_>) -> JoinData<'a> { + let rf = rs.field.clone(); + let m2m_table_alias = self.next_alias(); + let m2m_join_alias = self.next_alias(); + let outer_alias = self.next_alias(); + + let m2m_join_data = Table::from(self.build_to_many_select(rs, m2m_table_alias, ctx)) + .alias(m2m_join_alias.to_table_string()) + .on(ConditionTree::single(true.raw())) + .lateral(); + + let child_table = rf.as_table(ctx).alias(m2m_table_alias.to_table_string()); + + let inner = Select::from_table(child_table) + .value(Column::from((m2m_join_alias.to_table_string(), JSON_AGG_IDENT))) + .left_join(m2m_join_data) // join m2m table + .with_m2m_join_conditions(&rf.related_field(), m2m_table_alias, parent_alias, ctx) // adds join condition to the child table + // TODO: avoid clone filter + .with_filters(rs.args.filter.clone(), Some(m2m_join_alias), ctx) // adds query filters + .with_ordering(&rs.args, Some(m2m_join_alias.to_table_string()), ctx) // adds ordering stmts + .with_pagination(rs.args.take_abs(), rs.args.skip) + .comment("inner"); // adds pagination + + let outer = Select::from_table(Table::from(inner).alias(outer_alias.to_table_string())) + .value(json_agg()) + .comment("outer"); + + Table::from(outer) + .alias(m2m_join_alias_name(&rf)) + .on(ConditionTree::single(true.raw())) + .lateral() + } +} diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/select.rs b/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs similarity index 55% rename from query-engine/connectors/sql-query-connector/src/query_builder/select.rs rename to query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs index 27a4795789e4..d878ad63ec18 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/select.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs @@ -1,89 +1,139 @@ -use std::{borrow::Cow, collections::BTreeMap}; +mod lateral; +mod subquery; + +use std::borrow::Cow; use tracing::Span; +use psl::datamodel_connector::{ConnectorCapability, Flavour}; +use quaint::prelude::*; +use query_structure::*; + use crate::{ context::Context, - filter::alias::{Alias, AliasMode}, - model_extensions::{AsColumn, AsColumns, AsTable, ColumnIterator, RelationFieldExt}, + filter::alias::Alias, + model_extensions::{AsColumns, AsTable, ColumnIterator, RelationFieldExt}, ordering::OrderByBuilder, sql_trace::SqlTraceComment, }; -use quaint::prelude::*; -use query_structure::*; +use self::{lateral::LateralJoinSelectBuilder, subquery::SubqueriesSelectBuilder}; -pub const JSON_AGG_IDENT: &str = "__prisma_data__"; +pub(crate) const JSON_AGG_IDENT: &str = "__prisma_data__"; -#[derive(Debug, Default)] -pub(crate) struct SelectBuilder { - alias: Alias, -} +pub(crate) struct SelectBuilder; impl SelectBuilder { - pub(crate) fn next_alias(&mut self) -> Alias { - self.alias = self.alias.inc(AliasMode::Table); - self.alias + pub fn build(args: QueryArguments, selected_fields: &FieldSelection, ctx: &Context<'_>) -> Select<'static> { + if supports_lateral_join(&args) { + LateralJoinSelectBuilder::default().build(args, selected_fields, ctx) + } else { + SubqueriesSelectBuilder::default().build(args, selected_fields, ctx) + } } +} - pub(crate) fn build( +pub(crate) trait JoinSelectBuilder { + /// Build the select query for the given query arguments and selected fields. + /// This is the entry point for building a select query. `build_default_select` can be used to get a default select query. + fn build(&mut self, args: QueryArguments, selected_fields: &FieldSelection, ctx: &Context<'_>) -> Select<'static>; + /// Adds to `select` the SQL statements to fetch a 1-1 relation. + fn add_to_one_relation<'a>( &mut self, - args: QueryArguments, - selected_fields: &FieldSelection, + select: Select<'a>, + rs: &RelationSelection, + parent_alias: Alias, ctx: &Context<'_>, - ) -> Select<'static> { - let table_alias = self.next_alias(); - let table = args.model().as_table(ctx).alias(table_alias.to_table_string()); - - // SELECT ... FROM Table "t1" - let select = Select::from_table(table) - .with_selection(selected_fields, table_alias, ctx) - .with_ordering(&args, Some(table_alias.to_table_string()), ctx) - .with_pagination(args.take_abs(), args.skip) - .with_filters(args.filter, Some(table_alias), ctx) - .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id); - - // Adds joins for relations - let select = self.with_related_queries(select, selected_fields.relations(), table_alias, ctx); - - // Adds joins for relation aggregations. Other potential future kinds of virtual fields - // might or might not require joins and might be processed differently. - self.with_relation_aggregation_queries(select, selected_fields.virtuals(), table_alias, ctx) - } + ) -> Select<'a>; + /// Adds to `select` the SQL statements to fetch a 1-m relation. + fn add_to_many_relation<'a>( + &mut self, + select: Select<'a>, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a>; + /// Adds to `select` the SQL statements to fetch a m-n relation. + fn add_many_to_many_relation<'a>( + &mut self, + select: Select<'a>, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a>; + fn add_virtual_selection<'a>( + &mut self, + select: Select<'a>, + vs: &VirtualSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a>; + /// Build the top-level selection set + fn build_selection<'a>( + &mut self, + select: Select<'a>, + field: &SelectedField, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a>; + fn build_json_obj_fn( + &mut self, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Expression<'static>; + fn build_virtual_expr( + &mut self, + vs: &VirtualSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Expression<'static>; + /// Get the next alias for a table. + fn next_alias(&mut self) -> Alias; - fn with_related_queries<'a, 'b>( + fn with_selection<'a>( &mut self, - input: Select<'a>, - relation_selections: impl Iterator, + select: Select<'a>, + selected_fields: &FieldSelection, parent_alias: Alias, ctx: &Context<'_>, ) -> Select<'a> { - relation_selections.fold(input, |acc, rs| self.with_related_query(acc, rs, parent_alias, ctx)) + let select = selected_fields.selections().fold(select, |acc, selection| { + self.build_selection(acc, selection, parent_alias, ctx) + }); + + self.build_json_obj_virtual_selection(selected_fields.virtuals(), parent_alias, ctx) + .into_iter() + .fold(select, |acc, (alias, expr)| acc.value(expr.alias(alias))) } - fn with_related_query<'a>( + /// Builds the core select for a 1-1 relation. + fn build_to_one_select( &mut self, - select: Select<'a>, rs: &RelationSelection, parent_alias: Alias, + selection_modifier: impl FnOnce(Expression<'static>) -> Expression<'static>, ctx: &Context<'_>, - ) -> Select<'a> { - if rs.field.relation().is_many_to_many() { - // m2m relations need to left join on the relation table first - let m2m_join = self.build_m2m_join(rs, parent_alias, ctx); + ) -> (Select<'static>, Alias) { + let rf = &rs.field; + let child_table_alias = self.next_alias(); + let table = rs + .field + .related_field() + .as_table(ctx) + .alias(child_table_alias.to_table_string()); + let json_expr = self.build_json_obj_fn(rs, child_table_alias, ctx); - select.left_join(m2m_join) - } else { - let join_table_alias = join_alias_name(&rs.field); - let join_table = - Table::from(self.build_related_query_select(rs, parent_alias, ctx)).alias(join_table_alias); + let select = Select::from_table(table) + .with_join_conditions(rf, parent_alias, child_table_alias, ctx) + .with_filters(rs.args.filter.clone(), Some(child_table_alias), ctx) + .value(selection_modifier(json_expr)) + .limit(1); - // LEFT JOIN LATERAL ( ) AS ON TRUE - select.left_join(join_table.on(ConditionTree::single(true.raw())).lateral()) - } + (select, child_table_alias) } - fn build_related_query_select( + /// Builds the core select for a 1-m relation. + fn build_to_many_select( &mut self, rs: &RelationSelection, parent_alias: Alias, @@ -106,13 +156,9 @@ impl SelectBuilder { // SELECT JSON_BUILD_OBJECT() FROM ( ) let inner = Select::from_table(Table::from(root).alias(root_alias.to_table_string())) - .value(build_json_obj_fn(rs, ctx, root_alias).alias(JSON_AGG_IDENT)); - - // LEFT JOIN LATERAL () AS ON TRUE - let inner = self.with_related_queries(inner, rs.relations(), root_alias, ctx); - - // LEFT JOIN LATERAL ( ) ON TRUE - let inner = self.with_relation_aggregation_queries(inner, rs.virtuals(), root_alias, ctx); + .value(self.build_json_obj_fn(rs, root_alias, ctx).alias(JSON_AGG_IDENT)); + let inner = self.with_relations(inner, rs.relations(), root_alias, ctx); + let inner = self.with_virtual_selections(inner, rs.virtuals(), root_alias, ctx); let linking_fields = rs.field.related_field().linking_fields(); @@ -140,6 +186,14 @@ impl SelectBuilder { let inner = inner.with_columns(inner_selection.into()).comment("inner select"); + let middle_take = match connector_flavour(&rs.args) { + // On MySQL, using LIMIT makes the ordering of the JSON_AGG working. Beware, this is undocumented behavior. + // Note: Ideally, this should live in the MySQL select builder, but it's currently the only implementation difference + // between MySQL and Postgres, so we keep it here for now to avoid code duplication. + Flavour::Mysql if !rs.args.order_by.is_empty() => rs.args.take_abs().or(Some(i64::MAX)), + _ => rs.args.take_abs(), + }; + let middle = Select::from_table(Table::from(inner).alias(inner_alias.to_table_string())) // SELECT . .column(Column::from((inner_alias.to_table_string(), JSON_AGG_IDENT))) @@ -148,7 +202,7 @@ impl SelectBuilder { // WHERE ... .with_filters(rs.args.filter.clone(), Some(inner_alias), ctx) // LIMIT $1 OFFSET $2 - .with_pagination(rs.args.take_abs(), rs.args.skip) + .with_pagination(middle_take, rs.args.skip) .comment("middle select"); // SELECT COALESCE(JSON_AGG(), '[]') AS FROM ( ) as @@ -158,78 +212,94 @@ impl SelectBuilder { } } - fn build_m2m_join<'a>(&mut self, rs: &RelationSelection, parent_alias: Alias, ctx: &Context<'_>) -> JoinData<'a> { - let rf = rs.field.clone(); - let m2m_table_alias = self.next_alias(); - let m2m_join_alias = self.next_alias(); - let outer_alias = self.next_alias(); - - let left_columns = rf.related_field().m2m_columns(ctx); - let right_columns = ModelProjection::from(rf.model().primary_identifier()).as_columns(ctx); - - let join_conditions = - build_join_conditions((left_columns.into(), m2m_table_alias), (right_columns, parent_alias)); - - let m2m_join_data = Table::from(self.build_related_query_select(rs, m2m_table_alias, ctx)) - .alias(m2m_join_alias.to_table_string()) - .on(ConditionTree::single(true.raw())) - .lateral(); - - let child_table = rf.as_table(ctx).alias(m2m_table_alias.to_table_string()); - - let inner = Select::from_table(child_table) - .value(Column::from((m2m_join_alias.to_table_string(), JSON_AGG_IDENT))) - .left_join(m2m_join_data) // join m2m table - .and_where(join_conditions) // adds join condition to the child table - .with_ordering(&rs.args, Some(m2m_join_alias.to_table_string()), ctx) // adds ordering stmts - .with_filters(rs.args.filter.clone(), Some(m2m_join_alias), ctx) // adds query filters // TODO: avoid clone filter - .with_pagination(rs.args.take_abs(), rs.args.skip) - .comment("inner"); // adds pagination - - let outer = Select::from_table(Table::from(inner).alias(outer_alias.to_table_string())) - .value(json_agg()) - .comment("outer"); - - Table::from(outer) - .alias(m2m_join_alias_name(&rf)) - .on(ConditionTree::single(true.raw())) - .lateral() + fn with_relation<'a>( + &mut self, + select: Select<'a>, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + match (rs.field.is_list(), rs.field.relation().is_many_to_many()) { + (true, true) => self.add_many_to_many_relation(select, rs, parent_alias, ctx), + (true, false) => self.add_to_many_relation(select, rs, parent_alias, ctx), + (false, _) => self.add_to_one_relation(select, rs, parent_alias, ctx), + } + } + + fn with_relations<'a, 'b>( + &mut self, + input: Select<'a>, + relation_selections: impl Iterator, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + relation_selections.fold(input, |acc, rs| self.with_relation(acc, rs, parent_alias, ctx)) + } + + fn build_default_select(&mut self, args: &QueryArguments, ctx: &Context<'_>) -> (Select<'static>, Alias) { + let table_alias = self.next_alias(); + let table = args.model().as_table(ctx).alias(table_alias.to_table_string()); + + // SELECT ... FROM Table "t1" + let select = Select::from_table(table) + .with_ordering(args, Some(table_alias.to_table_string()), ctx) + .with_filters(args.filter.clone(), Some(table_alias), ctx) + .with_pagination(args.take_abs(), args.skip) + .append_trace(&Span::current()) + .add_trace_id(ctx.trace_id); + + (select, table_alias) } - fn with_relation_aggregation_queries<'a, 'b>( + fn with_virtual_selections<'a, 'b>( &mut self, select: Select<'a>, selections: impl Iterator, parent_alias: Alias, ctx: &Context<'_>, ) -> Select<'a> { - selections.fold(select, |acc, vs| { - self.with_relation_aggregation_query(acc, vs, parent_alias, ctx) - }) + selections.fold(select, |acc, vs| self.add_virtual_selection(acc, vs, parent_alias, ctx)) } - fn with_relation_aggregation_query<'a>( + fn build_virtual_select( &mut self, - select: Select<'a>, vs: &VirtualSelection, parent_alias: Alias, ctx: &Context<'_>, - ) -> Select<'a> { + ) -> Select<'static> { match vs { VirtualSelection::RelationCount(rf, filter) => { - let table_alias = relation_count_alias_name(rf); - - let relation_count_select = if rf.relation().is_many_to_many() { + if rf.relation().is_many_to_many() { self.build_relation_count_query_m2m(vs.db_alias(), rf, filter, parent_alias, ctx) } else { self.build_relation_count_query(vs.db_alias(), rf, filter, parent_alias, ctx) - }; + } + } + } + } - let table = Table::from(relation_count_select).alias(table_alias); + fn build_json_obj_virtual_selection<'a>( + &mut self, + virtual_fields: impl Iterator, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Vec<(Cow<'static, str>, Expression<'static>)> { + let mut selected_objects = std::collections::BTreeMap::new(); - select.left_join_lateral(table.on(ConditionTree::single(true.raw()))) - } + for vs in virtual_fields { + let (object_name, field_name) = vs.serialized_name(); + let virtual_expr = self.build_virtual_expr(vs, parent_alias, ctx); + + selected_objects + .entry(object_name) + .or_insert(Vec::new()) + .push((field_name.to_owned().into(), virtual_expr)); } + + selected_objects + .into_iter() + .map(|(name, fields)| (name.into(), json_build_object(fields).into())) + .collect() } fn build_relation_count_query<'a>( @@ -274,7 +344,10 @@ impl SelectBuilder { let m2m_join_conditions = { let left_columns = rf.join_columns(ctx); let right_columns = ModelProjection::from(rf.related_field().linking_fields()).as_columns(ctx); - build_join_conditions((left_columns, m2m_table_alias), (right_columns, related_table_alias)) + build_join_conditions( + (left_columns, Some(m2m_table_alias)), + (right_columns, Some(related_table_alias)), + ) }; let m2m_join_data = rf @@ -285,7 +358,10 @@ impl SelectBuilder { let aggregation_join_conditions = { let left_columns = rf.related_field().m2m_columns(ctx); let right_columns = ModelProjection::from(rf.model().primary_identifier()).as_columns(ctx); - build_join_conditions((left_columns.into(), m2m_table_alias), (right_columns, parent_alias)) + build_join_conditions( + (left_columns.into(), Some(m2m_table_alias)), + (right_columns, Some(parent_alias)), + ) }; let select = Select::from_table(related_table) @@ -298,19 +374,24 @@ impl SelectBuilder { } } -trait SelectBuilderExt<'a> { +pub(crate) trait SelectBuilderExt<'a> { fn with_filters(self, filter: Option, parent_alias: Option, ctx: &Context<'_>) -> Select<'a>; fn with_pagination(self, take: Option, skip: Option) -> Select<'a>; fn with_ordering(self, args: &QueryArguments, parent_alias: Option, ctx: &Context<'_>) -> Select<'a>; fn with_join_conditions( self, rf: &RelationField, - parent_alias: Alias, - child_alias: Alias, + left_alias: Alias, + right_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a>; + fn with_m2m_join_conditions( + self, + rf: &RelationField, + left_alias: Alias, + right_alias: Alias, ctx: &Context<'_>, ) -> Select<'a>; - fn with_selection(self, selected_fields: &FieldSelection, table_alias: Alias, ctx: &Context<'_>) -> Select<'a>; - fn with_virtuals_from_selection(self, selected_fields: &FieldSelection) -> Select<'a>; fn with_columns(self, columns: ColumnIterator) -> Select<'a>; } @@ -364,45 +445,21 @@ impl<'a> SelectBuilderExt<'a> for Select<'a> { fn with_join_conditions( self, rf: &RelationField, - parent_alias: Alias, - child_alias: Alias, + left_alias: Alias, + right_alias: Alias, ctx: &Context<'_>, ) -> Select<'a> { - let join_columns = rf.join_columns(ctx); - let related_join_columns = ModelProjection::from(rf.related_field().linking_fields()).as_columns(ctx); - - let conditions = build_join_conditions((join_columns, parent_alias), (related_join_columns, child_alias)); - - // WHERE Parent.id = Child.id - self.and_where(conditions) + self.and_where(rf.join_conditions(Some(left_alias), Some(right_alias), ctx)) } - fn with_selection(self, selected_fields: &FieldSelection, table_alias: Alias, ctx: &Context<'_>) -> Select<'a> { - selected_fields - .selections() - .fold(self, |acc, selection| match selection { - SelectedField::Scalar(sf) => acc.column( - sf.as_column(ctx) - .table(table_alias.to_table_string()) - .set_is_selected(true), - ), - SelectedField::Relation(rs) => { - let table_name = match rs.field.relation().is_many_to_many() { - true => m2m_join_alias_name(&rs.field), - false => join_alias_name(&rs.field), - }; - - acc.value(Column::from((table_name, JSON_AGG_IDENT)).alias(rs.field.name().to_owned())) - } - _ => acc, - }) - .with_virtuals_from_selection(selected_fields) - } - - fn with_virtuals_from_selection(self, selected_fields: &FieldSelection) -> Select<'a> { - build_virtual_selection(selected_fields.virtuals()) - .into_iter() - .fold(self, |select, (alias, expr)| select.value(expr.alias(alias))) + fn with_m2m_join_conditions( + self, + rf: &RelationField, + left_alias: Alias, + right_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + self.and_where(rf.m2m_join_conditions(Some(left_alias), Some(right_alias), ctx)) } fn with_columns(self, columns: ColumnIterator) -> Select<'a> { @@ -410,51 +467,45 @@ impl<'a> SelectBuilderExt<'a> for Select<'a> { } } -fn build_join_conditions( - (left_columns, left_alias): (ColumnIterator, Alias), - (right_columns, right_alias): (ColumnIterator, Alias), -) -> ConditionTree<'static> { - left_columns - .zip(right_columns) - .fold(None::, |acc, (a, b)| { - let a = a.table(left_alias.to_table_string()); - let b = b.table(right_alias.to_table_string()); - let condition = a.equals(b); - - match acc { - Some(acc) => Some(acc.and(condition)), - None => Some(condition.into()), - } - }) - .unwrap() +pub(crate) trait JoinConditionExt { + fn join_conditions( + &self, + left_alias: Option, + right_alias: Option, + ctx: &Context<'_>, + ) -> ConditionTree<'static>; + fn m2m_join_conditions( + &self, + left_alias: Option, + right_alias: Option, + ctx: &Context<'_>, + ) -> ConditionTree<'static>; } -fn build_json_obj_fn(rs: &RelationSelection, ctx: &Context<'_>, root_alias: Alias) -> Function<'static> { - let build_obj_params = rs - .selections - .iter() - .filter_map(|f| match f { - SelectedField::Scalar(sf) => Some(( - Cow::from(sf.db_name().to_owned()), - Expression::from(sf.as_column(ctx).table(root_alias.to_table_string())), - )), - SelectedField::Relation(rs) => { - let table_name = match rs.field.relation().is_many_to_many() { - true => m2m_join_alias_name(&rs.field), - false => join_alias_name(&rs.field), - }; - - Some(( - Cow::from(rs.field.name().to_owned()), - Expression::from(Column::from((table_name, JSON_AGG_IDENT))), - )) - } - _ => None, - }) - .chain(build_virtual_selection(rs.virtuals())) - .collect(); +impl JoinConditionExt for RelationField { + fn join_conditions( + &self, + left_alias: Option, + right_alias: Option, + ctx: &Context<'_>, + ) -> ConditionTree<'static> { + let left_columns = self.join_columns(ctx); + let right_columns = ModelProjection::from(self.related_field().linking_fields()).as_columns(ctx); + + build_join_conditions((left_columns, left_alias), (right_columns, right_alias)) + } - json_build_object(build_obj_params) + fn m2m_join_conditions( + &self, + left_alias: Option, + right_alias: Option, + ctx: &Context<'_>, + ) -> ConditionTree<'static> { + let left_columns = self.m2m_columns(ctx); + let right_columns = ModelProjection::from(self.related_model().primary_identifier()).as_columns(ctx); + + build_join_conditions((left_columns.into(), left_alias), (right_columns, right_alias)) + } } fn order_by_selection(rs: &RelationSelection) -> FieldSelection { @@ -515,41 +566,52 @@ fn m2m_join_alias_name(rf: &RelationField) -> String { format!("{}_{}_m2m", rf.model().name(), rf.name()) } +fn build_join_conditions( + left: (ColumnIterator, Option), + right: (ColumnIterator, Option), +) -> ConditionTree<'static> { + let (left_columns, left_alias) = left; + let (right_columns, right_alias) = right; + + left_columns + .into_iter() + .zip(right_columns) + .fold(None::, |acc, (a, b)| { + let a = a.opt_table(left_alias.map(|left| left.to_table_string())); + let b = b.opt_table(right_alias.map(|right| right.to_table_string())); + let condition = a.equals(b); + + match acc { + Some(acc) => Some(acc.and(condition)), + None => Some(condition.into()), + } + }) + .unwrap() +} + fn json_agg() -> Function<'static> { coalesce(vec![ json_array_agg(Column::from(JSON_AGG_IDENT)).into(), - Expression::from("[]".raw()), + Expression::from(Value::json(empty_json_array()).raw()), ]) .alias(JSON_AGG_IDENT) } -fn build_virtual_selection<'a>( - virtual_fields: impl Iterator, -) -> Vec<(Cow<'static, str>, Expression<'static>)> { - let mut selected_objects = BTreeMap::new(); +#[inline] +fn empty_json_array() -> serde_json::Value { + serde_json::Value::Array(Vec::new()) +} - for vs in virtual_fields { - match vs { - VirtualSelection::RelationCount(rf, _) => { - let (object_name, field_name) = vs.serialized_name(); - - let coalesce_args: Vec> = vec![ - Column::from((relation_count_alias_name(rf), vs.db_alias())).into(), - 0.raw().into(), - ]; - - selected_objects - .entry(object_name) - .or_insert(Vec::new()) - .push((field_name.to_owned().into(), coalesce(coalesce_args).into())); - } - } - } +fn connector_flavour(args: &QueryArguments) -> Flavour { + args.model().dm.schema.connector.flavour() +} - selected_objects - .into_iter() - .map(|(name, fields)| (name.into(), json_build_object(fields).into())) - .collect() +fn supports_lateral_join(args: &QueryArguments) -> bool { + args.model() + .dm + .schema + .connector + .has_capability(ConnectorCapability::LateralJoin) } fn relation_count_alias_name(rf: &RelationField) -> String { diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/select/subquery.rs b/query-engine/connectors/sql-query-connector/src/query_builder/select/subquery.rs new file mode 100644 index 000000000000..202d42780e8b --- /dev/null +++ b/query-engine/connectors/sql-query-connector/src/query_builder/select/subquery.rs @@ -0,0 +1,190 @@ +use super::*; + +use crate::{ + context::Context, + filter::alias::{Alias, AliasMode}, + model_extensions::*, +}; + +use quaint::ast::*; +use query_structure::*; + +/// Select builder for joined queries. Relations are resolved using correlated sub-queries. +#[derive(Debug, Default)] +pub(crate) struct SubqueriesSelectBuilder { + alias: Alias, +} + +impl JoinSelectBuilder for SubqueriesSelectBuilder { + /// Builds a SELECT statement for the given query arguments and selected fields. + /// + /// ```sql + /// SELECT + /// id, + /// name, + /// ( + /// SELECT JSON_OBJECT(<...>) FROM "Post" WHERE "Post"."authorId" = "User"."id + /// ) as `post` + /// FROM "User" + /// ``` + fn build(&mut self, args: QueryArguments, selected_fields: &FieldSelection, ctx: &Context<'_>) -> Select<'static> { + let (select, alias) = self.build_default_select(&args, ctx); + + self.with_selection(select, selected_fields, alias, ctx) + } + + fn build_selection<'a>( + &mut self, + select: Select<'a>, + field: &SelectedField, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + match field { + SelectedField::Scalar(sf) => select.column( + sf.as_column(ctx) + .table(parent_alias.to_table_string()) + .set_is_selected(true), + ), + SelectedField::Relation(rs) => self.with_relation(select, rs, parent_alias, ctx), + _ => select, + } + } + + fn add_to_one_relation<'a>( + &mut self, + select: Select<'a>, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + let (subselect, _) = self.build_to_one_select(rs, parent_alias, |x| x, ctx); + + select.value(Expression::from(subselect).alias(rs.field.name().to_owned())) + } + + fn add_to_many_relation<'a>( + &mut self, + select: Select<'a>, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + let subselect = self.build_to_many_select(rs, parent_alias, ctx); + + select.value(Expression::from(subselect).alias(rs.field.name().to_owned())) + } + + fn add_many_to_many_relation<'a>( + &mut self, + select: Select<'a>, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + let subselect = self.build_m2m_select(rs, parent_alias, ctx); + + select.value(Expression::from(subselect).alias(rs.field.name().to_owned())) + } + + fn add_virtual_selection<'a>( + &mut self, + select: Select<'a>, + vs: &VirtualSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + let virtual_select = self.build_virtual_select(vs, parent_alias, ctx); + let alias = relation_count_alias_name(vs.relation_field()); + + select.value(Expression::from(virtual_select).alias(alias)) + } + + fn build_json_obj_fn( + &mut self, + rs: &RelationSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Expression<'static> { + let virtuals = self.build_json_obj_virtual_selection(rs.virtuals(), parent_alias, ctx); + let build_obj_params = rs + .selections + .iter() + .filter_map(|field| match field { + SelectedField::Scalar(sf) => Some(( + Cow::from(sf.db_name().to_owned()), + Expression::from(sf.as_column(ctx).table(parent_alias.to_table_string())), + )), + SelectedField::Relation(rs) => Some(( + Cow::from(rs.field.name().to_owned()), + Expression::from(self.with_relation(Select::default(), rs, parent_alias, ctx)), + )), + _ => None, + }) + .chain(virtuals) + .collect(); + + json_build_object(build_obj_params).into() + } + + fn build_virtual_expr( + &mut self, + vs: &VirtualSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Expression<'static> { + coalesce([ + Expression::from(self.build_virtual_select(vs, parent_alias, ctx)), + Expression::from(0.raw()), + ]) + .into() + } + + fn next_alias(&mut self) -> Alias { + self.alias = self.alias.inc(AliasMode::Table); + self.alias + } +} + +impl SubqueriesSelectBuilder { + fn build_m2m_select<'a>(&mut self, rs: &RelationSelection, parent_alias: Alias, ctx: &Context<'_>) -> Select<'a> { + let rf = rs.field.clone(); + let m2m_table_alias = self.next_alias(); + let root_alias = self.next_alias(); + let outer_alias = self.next_alias(); + + let m2m_join_data = + rf.related_model() + .as_table(ctx) + .on(rf.m2m_join_conditions(Some(m2m_table_alias), None, ctx)); + + let m2m_table = rf.as_table(ctx).alias(m2m_table_alias.to_table_string()); + + let root = Select::from_table(m2m_table) + .inner_join(m2m_join_data) + .value(rf.related_model().as_table(ctx).asterisk()) + .with_ordering(&rs.args, None, ctx) // adds ordering stmts + // Keep join conditions _before_ user filters to ensure index is used first + .and_where( + rf.related_field() + .m2m_join_conditions(Some(m2m_table_alias), Some(parent_alias), ctx), + ) // adds join condition to the child table + .with_filters(rs.args.filter.clone(), None, ctx) // adds query filters + .comment("root"); + + // On MySQL, using LIMIT makes the ordering of the JSON_AGG working. Beware, this is undocumented behavior. + let take = match rs.args.order_by.is_empty() { + true => rs.args.take_abs(), + false => rs.args.take_abs().or(Some(i64::MAX)), + }; + + let inner = Select::from_table(Table::from(root).alias(root_alias.to_table_string())) + .value(self.build_json_obj_fn(rs, root_alias, ctx).alias(JSON_AGG_IDENT)) + .with_pagination(take, rs.args.skip) + .comment("inner"); // adds pagination + + Select::from_table(Table::from(inner).alias(outer_alias.to_table_string())) + .value(json_agg()) + .comment("outer") + } +} diff --git a/query-engine/core/src/query_graph_builder/read/utils.rs b/query-engine/core/src/query_graph_builder/read/utils.rs index 369cd312d4bd..19222baebf7d 100644 --- a/query-engine/core/src/query_graph_builder/read/utils.rs +++ b/query-engine/core/src/query_graph_builder/read/utils.rs @@ -1,6 +1,5 @@ use super::*; use crate::{ArgumentListLookup, FieldPair, ParsedField, ReadQuery}; -use psl::{datamodel_connector::ConnectorCapability, PreviewFeature}; use query_structure::{prelude::*, RelationLoadStrategy}; use schema::{ constants::{aggregations::*, args}, @@ -72,8 +71,7 @@ fn pairs_to_selections( where T: Into, { - let should_collect_relation_selection = query_schema.has_capability(ConnectorCapability::LateralJoin) - && query_schema.has_feature(PreviewFeature::RelationJoins); + let should_collect_relation_selection = query_schema.can_resolve_relation_with_joins(); let parent = parent.into(); @@ -259,8 +257,7 @@ pub(crate) fn get_relation_load_strategy( nested_queries: &[ReadQuery], query_schema: &QuerySchema, ) -> RelationLoadStrategy { - if query_schema.has_feature(PreviewFeature::RelationJoins) - && query_schema.has_capability(ConnectorCapability::LateralJoin) + if query_schema.can_resolve_relation_with_joins() && cursor.is_none() && distinct.is_none() && !nested_queries.iter().any(|q| match q { diff --git a/query-engine/query-structure/src/field/scalar.rs b/query-engine/query-structure/src/field/scalar.rs index 2e6474947227..becd438db276 100644 --- a/query-engine/query-structure/src/field/scalar.rs +++ b/query-engine/query-structure/src/field/scalar.rs @@ -155,15 +155,22 @@ impl ScalarField { } pub fn native_type(&self) -> Option { - let (_, name, args, span) = match self.id { + let connector = self.dm.schema.connector; + + let raw_nt = match self.id { ScalarFieldId::InModel(id) => self.dm.walk(id).raw_native_type(), ScalarFieldId::InCompositeType(id) => self.dm.walk(id).raw_native_type(), - }?; - let connector = self.dm.schema.connector; + }; + + let psl_nt = raw_nt + .and_then(|(_, name, args, span)| connector.parse_native_type(name, args, span, &mut Default::default())); - let nt = connector - .parse_native_type(name, args, span, &mut Default::default()) - .unwrap(); + let scalar_type = match self.id { + ScalarFieldId::InModel(id) => self.dm.walk(id).scalar_type(), + ScalarFieldId::InCompositeType(id) => self.dm.walk(id).scalar_type(), + }; + + let nt = psl_nt.or_else(|| scalar_type.and_then(|st| connector.default_native_type_for_scalar_type(&st)))?; Some(NativeTypeInstance { native_type: nt, @@ -178,6 +185,13 @@ impl ScalarField { connector.parse_json_datetime(value, nt) } + pub fn parse_json_bytes(&self, value: &str) -> PrismaValueResult> { + let nt = self.native_type().map(|nt| nt.native_type); + let connector = self.dm.schema.connector; + + connector.parse_json_bytes(value, nt) + } + pub fn is_autoincrement(&self) -> bool { match self.id { ScalarFieldId::InModel(id) => self.dm.walk(id).is_autoincrement(), diff --git a/query-engine/query-structure/src/field_selection.rs b/query-engine/query-structure/src/field_selection.rs index f2b1fccd9c5b..4558eb77f335 100644 --- a/query-engine/query-structure/src/field_selection.rs +++ b/query-engine/query-structure/src/field_selection.rs @@ -326,6 +326,12 @@ impl VirtualSelection { Self::RelationCount(_, _) => (TypeIdentifier::Int, FieldArity::Required), } } + + pub fn relation_field(&self) -> &RelationField { + match self { + VirtualSelection::RelationCount(rf, _) => rf, + } + } } impl Display for VirtualSelection { diff --git a/query-engine/schema/src/build/enum_types.rs b/query-engine/schema/src/build/enum_types.rs index c723cfbe587b..c878226e76bf 100644 --- a/query-engine/schema/src/build/enum_types.rs +++ b/query-engine/schema/src/build/enum_types.rs @@ -117,7 +117,7 @@ pub(crate) fn relation_load_strategy(ctx: &QuerySchema) -> Option { let ident = Identifier::new_prisma(IdentifierType::RelationLoadStrategy); - let values = if ctx.has_capability(ConnectorCapability::LateralJoin) { + let values = if ctx.can_resolve_relation_with_joins() { vec![load_strategy::QUERY.to_owned(), load_strategy::JOIN.to_owned()] } else { vec![load_strategy::QUERY.to_owned()] diff --git a/query-engine/schema/src/query_schema.rs b/query-engine/schema/src/query_schema.rs index 3098a96f1597..4859984d11a6 100644 --- a/query-engine/schema/src/query_schema.rs +++ b/query-engine/schema/src/query_schema.rs @@ -96,6 +96,12 @@ impl QuerySchema { || self.has_capability(ConnectorCapability::FullTextSearchWithIndex)) } + pub fn can_resolve_relation_with_joins(&self) -> bool { + self.has_feature(PreviewFeature::RelationJoins) + && (self.has_capability(ConnectorCapability::LateralJoin) + || self.has_capability(ConnectorCapability::CorrelatedSubqueries)) + } + pub fn has_feature(&self, feature: PreviewFeature) -> bool { self.preview_features.contains(feature) } diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs index 3516d0136045..3b36829cfcf0 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs @@ -465,7 +465,7 @@ fn push_column_for_builtin_scalar_type( let native_type = field .native_type_instance(connector) - .unwrap_or_else(|| connector.default_native_type_for_scalar_type(&scalar_type)); + .or_else(|| connector.default_native_type_for_scalar_type(&scalar_type)); enum ColumnDefault { Available(sql::DefaultValue), @@ -521,7 +521,7 @@ fn push_column_for_builtin_scalar_type( family, full_data_type: String::new(), arity: column_arity(field.ast_field().arity), - native_type: Some(native_type), + native_type, }, auto_increment: field.is_autoincrement() || ctx.flavour.field_is_implicit_autoincrement_primary_key(field), description: None,