diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs b/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs index d42dc627bf62..153354c518c8 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/coerce.rs @@ -5,18 +5,26 @@ use std::{io, str::FromStr}; use crate::{query_arguments_ext::QueryArgumentsExt, SqlError}; +pub(crate) enum IndexedSelection<'a> { + Relation(&'a RelationSelection), + Virtual(&'a str), +} + /// Coerces relations resolved as JSON to PrismaValues. /// Note: Some in-memory processing is baked into this function too for performance reasons. pub(crate) fn coerce_record_with_json_relation( record: &mut Record, - rs_indexes: Vec<(usize, &RelationSelection)>, + indexes: &[(usize, IndexedSelection<'_>)], ) -> crate::Result<()> { - for (val_idx, rs) in rs_indexes { - let val = record.values.get_mut(val_idx).unwrap(); + for (val_idx, kind) in indexes { + let val = record.values.get_mut(*val_idx).unwrap(); // TODO(perf): Find ways to avoid serializing and deserializing multiple times. let json_val: serde_json::Value = serde_json::from_str(val.as_json().unwrap()).unwrap(); - *val = coerce_json_relation_to_pv(json_val, rs)?; + *val = match kind { + IndexedSelection::Relation(rs) => coerce_json_relation_to_pv(json_val, rs)?, + IndexedSelection::Virtual(name) => coerce_json_virtual_field_to_pv(name, json_val)?, + }; } Ok(()) @@ -57,7 +65,7 @@ fn coerce_json_relation_to_pv(value: serde_json::Value, rs: &RelationSelection) .map(|value| coerce_json_relation_to_pv(value, rs)); // TODO(HACK): We probably want to update the sql builder instead to not aggregate to-one relations as array - // If the arary is empty, it means there's no relations, so we coerce it to + // If the array is empty, it means there's no relations, so we coerce it to if let Some(val) = coerced { val } else { @@ -69,16 +77,20 @@ fn coerce_json_relation_to_pv(value: serde_json::Value, rs: &RelationSelection) let related_model = rs.field.related_model(); for (key, value) in obj { - match related_model.fields().all().find(|f| f.db_name() == key).unwrap() { - Field::Scalar(sf) => { + match related_model.fields().all().find(|f| f.db_name() == key) { + Some(Field::Scalar(sf)) => { map.push((key, coerce_json_scalar_to_pv(value, &sf)?)); } - Field::Relation(rf) => { + Some(Field::Relation(rf)) => { // TODO: optimize this if let Some(nested_selection) = relations.iter().find(|rs| rs.field == rf) { map.push((key, coerce_json_relation_to_pv(value, nested_selection)?)); } } + None => { + let coerced_value = coerce_json_virtual_field_to_pv(&key, value)?; + map.push((key, coerced_value)); + } _ => (), } } @@ -191,27 +203,51 @@ pub(crate) fn coerce_json_scalar_to_pv(value: serde_json::Value, sf: &ScalarFiel } } +fn coerce_json_virtual_field_to_pv(key: &str, value: serde_json::Value) -> crate::Result { + match value { + serde_json::Value::Object(obj) => { + let values: crate::Result> = obj + .into_iter() + .map(|(key, value)| coerce_json_virtual_field_to_pv(&key, value).map(|value| (key, value))) + .collect(); + Ok(PrismaValue::Object(values?)) + } + + serde_json::Value::Number(num) => num + .as_i64() + .ok_or_else(|| { + build_generic_conversion_error(format!( + "Unexpected numeric value {num} for virtual field '{key}': only integers are supported" + )) + }) + .map(PrismaValue::Int), + + _ => Err(build_generic_conversion_error(format!( + "Field '{key}' is not a model field and doesn't have a supported type for a virtual field" + ))), + } +} + fn build_conversion_error(sf: &ScalarField, from: &str, to: &str) -> SqlError { let container_name = sf.container().name(); let field_name = sf.name(); - let error = io::Error::new( - io::ErrorKind::InvalidData, - format!("Unexpected conversion failure for field {container_name}.{field_name} from {from} to {to}."), - ); - - SqlError::ConversionError(error.into()) + build_generic_conversion_error(format!( + "Unexpected conversion failure for field {container_name}.{field_name} from {from} to {to}." + )) } fn build_conversion_error_with_reason(sf: &ScalarField, from: &str, to: &str, reason: &str) -> SqlError { let container_name = sf.container().name(); let field_name = sf.name(); - let error = io::Error::new( - io::ErrorKind::InvalidData, - format!("Unexpected conversion failure for field {container_name}.{field_name} from {from} to {to}. Reason: ${reason}"), - ); + build_generic_conversion_error(format!( + "Unexpected conversion failure for field {container_name}.{field_name} from {from} to {to}. Reason: {reason}" + )) +} +fn build_generic_conversion_error(message: String) -> SqlError { + let error = io::Error::new(io::ErrorKind::InvalidData, message); SqlError::ConversionError(error.into()) } diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs index 32e9dba67b79..07287fee303c 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs @@ -1,4 +1,4 @@ -use super::coerce::coerce_record_with_json_relation; +use super::coerce::{coerce_record_with_json_relation, IndexedSelection}; use crate::{ column_metadata, model_extensions::*, @@ -33,9 +33,14 @@ pub(crate) async fn get_single_record_joins( selected_fields: &FieldSelection, ctx: &Context<'_>, ) -> crate::Result> { - let field_names: Vec<_> = selected_fields.db_names().collect(); - let idents = selected_fields.type_identifiers_with_arities(); - let rs_indexes = get_relation_selection_indexes(selected_fields.relations().collect(), &field_names); + let field_names: Vec<_> = selected_fields.db_names_grouping_virtuals().collect(); + let idents = selected_fields.type_identifiers_with_arities_grouping_virtuals(); + + let indexes = get_selection_indexes( + selected_fields.relations().collect(), + selected_fields.virtuals().collect(), + &field_names, + ); let query = query_builder::select::SelectBuilder::default().build( QueryArguments::from((model.clone(), filter.clone())), @@ -46,7 +51,7 @@ pub(crate) async fn get_single_record_joins( let mut record = execute_find_one(conn, query, &idents, &field_names, ctx).await?; if let Some(record) = record.as_mut() { - coerce_record_with_json_relation(record, rs_indexes)?; + coerce_record_with_json_relation(record, &indexes)?; }; Ok(record.map(|record| SingleRecord { record, field_names })) @@ -125,10 +130,15 @@ pub(crate) async fn get_many_records_joins( selected_fields: &FieldSelection, ctx: &Context<'_>, ) -> crate::Result { - let field_names: Vec<_> = selected_fields.db_names().collect(); - let idents = selected_fields.type_identifiers_with_arities(); + let field_names: Vec<_> = selected_fields.db_names_grouping_virtuals().collect(); + let idents = selected_fields.type_identifiers_with_arities_grouping_virtuals(); let meta = column_metadata::create(field_names.as_slice(), idents.as_slice()); - let rs_indexes = get_relation_selection_indexes(selected_fields.relations().collect(), &field_names); + + let indexes = get_selection_indexes( + selected_fields.relations().collect(), + selected_fields.virtuals().collect(), + &field_names, + ); let mut records = ManyRecords::new(field_names.clone()); @@ -151,7 +161,7 @@ pub(crate) async fn get_many_records_joins( let mut record = Record::from(item); // Coerces json values to prisma values - coerce_record_with_json_relation(&mut record, rs_indexes.clone())?; + coerce_record_with_json_relation(&mut record, &indexes)?; records.push(record) } @@ -395,18 +405,27 @@ async fn group_by_aggregate( .collect()) } -/// Find the indexes of the relation records to traverse a set of records faster when coercing JSON values -fn get_relation_selection_indexes<'a>( - selections: Vec<&'a RelationSelection>, - field_names: &[String], -) -> Vec<(usize, &'a RelationSelection)> { - let mut output: Vec<(usize, &RelationSelection)> = Vec::new(); - - for (idx, field_name) in field_names.iter().enumerate() { - if let Some(rs) = selections.iter().find(|rq| rq.field.name() == *field_name) { - output.push((idx, rs)); - } - } - - output +/// Find the indexes of the relation records and the virtual selection objects to traverse a set of +/// records faster when coercing JSON values. +fn get_selection_indexes<'a>( + relations: Vec<&'a RelationSelection>, + virtuals: Vec<&'a VirtualSelection>, + field_names: &'a [String], +) -> Vec<(usize, IndexedSelection<'a>)> { + field_names + .iter() + .enumerate() + .filter_map(|(idx, field_name)| { + relations + .iter() + .find_map(|rs| (rs.field.name() == field_name).then_some(IndexedSelection::Relation(rs))) + .or_else(|| { + virtuals.iter().find_map(|vs| { + let obj_name = vs.serialized_name().0; + (obj_name == field_name).then_some(IndexedSelection::Virtual(obj_name)) + }) + }) + .map(|indexed_selection| (idx, indexed_selection)) + }) + .collect() } diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/select.rs b/query-engine/connectors/sql-query-connector/src/query_builder/select.rs index d6f30fe413f8..27a4795789e4 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/select.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/select.rs @@ -1,4 +1,4 @@ -use std::borrow::Cow; +use std::{borrow::Cow, collections::BTreeMap}; use tracing::Span; use crate::{ @@ -44,7 +44,11 @@ impl SelectBuilder { .add_trace_id(ctx.trace_id); // Adds joins for relations - self.with_related_queries(select, selected_fields.relations(), table_alias, ctx) + let select = self.with_related_queries(select, selected_fields.relations(), table_alias, ctx); + + // Adds joins for relation aggregations. Other potential future kinds of virtual fields + // might or might not require joins and might be processed differently. + self.with_relation_aggregation_queries(select, selected_fields.virtuals(), table_alias, ctx) } fn with_related_queries<'a, 'b>( @@ -107,6 +111,9 @@ impl SelectBuilder { // LEFT JOIN LATERAL () AS ON TRUE let inner = self.with_related_queries(inner, rs.relations(), root_alias, ctx); + // LEFT JOIN LATERAL ( ) ON TRUE + let inner = self.with_relation_aggregation_queries(inner, rs.virtuals(), root_alias, ctx); + let linking_fields = rs.field.related_field().linking_fields(); if rs.field.relation().is_many_to_many() { @@ -160,20 +167,8 @@ impl SelectBuilder { let left_columns = rf.related_field().m2m_columns(ctx); let right_columns = ModelProjection::from(rf.model().primary_identifier()).as_columns(ctx); - let join_conditions = left_columns - .into_iter() - .zip(right_columns) - .fold(None::, |acc, (a, b)| { - let a = a.table(m2m_table_alias.to_table_string()); - let b = b.table(parent_alias.to_table_string()); - let condition = a.equals(b); - - match acc { - Some(acc) => Some(acc.and(condition)), - None => Some(condition.into()), - } - }) - .unwrap(); + let join_conditions = + build_join_conditions((left_columns.into(), m2m_table_alias), (right_columns, parent_alias)); let m2m_join_data = Table::from(self.build_related_query_select(rs, m2m_table_alias, ctx)) .alias(m2m_join_alias.to_table_string()) @@ -200,6 +195,107 @@ impl SelectBuilder { .on(ConditionTree::single(true.raw())) .lateral() } + + fn with_relation_aggregation_queries<'a, 'b>( + &mut self, + select: Select<'a>, + selections: impl Iterator, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + selections.fold(select, |acc, vs| { + self.with_relation_aggregation_query(acc, vs, parent_alias, ctx) + }) + } + + fn with_relation_aggregation_query<'a>( + &mut self, + select: Select<'a>, + vs: &VirtualSelection, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + match vs { + VirtualSelection::RelationCount(rf, filter) => { + let table_alias = relation_count_alias_name(rf); + + let relation_count_select = if rf.relation().is_many_to_many() { + self.build_relation_count_query_m2m(vs.db_alias(), rf, filter, parent_alias, ctx) + } else { + self.build_relation_count_query(vs.db_alias(), rf, filter, parent_alias, ctx) + }; + + let table = Table::from(relation_count_select).alias(table_alias); + + select.left_join_lateral(table.on(ConditionTree::single(true.raw()))) + } + } + } + + fn build_relation_count_query<'a>( + &mut self, + selection_name: impl Into>, + rf: &RelationField, + filter: &Option, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + let related_table_alias = self.next_alias(); + + let related_table = rf + .related_model() + .as_table(ctx) + .alias(related_table_alias.to_table_string()); + + let select = Select::from_table(related_table) + .value(count(asterisk()).alias(selection_name)) + .with_join_conditions(rf, parent_alias, related_table_alias, ctx) + .with_filters(filter.clone(), Some(related_table_alias), ctx); + + select + } + + fn build_relation_count_query_m2m<'a>( + &mut self, + selection_name: impl Into>, + rf: &RelationField, + filter: &Option, + parent_alias: Alias, + ctx: &Context<'_>, + ) -> Select<'a> { + let related_table_alias = self.next_alias(); + let m2m_table_alias = self.next_alias(); + + let related_table = rf + .related_model() + .as_table(ctx) + .alias(related_table_alias.to_table_string()); + + let m2m_join_conditions = { + let left_columns = rf.join_columns(ctx); + let right_columns = ModelProjection::from(rf.related_field().linking_fields()).as_columns(ctx); + build_join_conditions((left_columns, m2m_table_alias), (right_columns, related_table_alias)) + }; + + let m2m_join_data = rf + .as_table(ctx) + .alias(m2m_table_alias.to_table_string()) + .on(m2m_join_conditions); + + let aggregation_join_conditions = { + let left_columns = rf.related_field().m2m_columns(ctx); + let right_columns = ModelProjection::from(rf.model().primary_identifier()).as_columns(ctx); + build_join_conditions((left_columns.into(), m2m_table_alias), (right_columns, parent_alias)) + }; + + let select = Select::from_table(related_table) + .value(count(asterisk()).alias(selection_name)) + .left_join(m2m_join_data) + .and_where(aggregation_join_conditions) + .with_filters(filter.clone(), Some(related_table_alias), ctx); + + select + } } trait SelectBuilderExt<'a> { @@ -214,6 +310,7 @@ trait SelectBuilderExt<'a> { ctx: &Context<'_>, ) -> Select<'a>; fn with_selection(self, selected_fields: &FieldSelection, table_alias: Alias, ctx: &Context<'_>) -> Select<'a>; + fn with_virtuals_from_selection(self, selected_fields: &FieldSelection) -> Select<'a>; fn with_columns(self, columns: ColumnIterator) -> Select<'a>; } @@ -274,21 +371,9 @@ impl<'a> SelectBuilderExt<'a> for Select<'a> { let join_columns = rf.join_columns(ctx); let related_join_columns = ModelProjection::from(rf.related_field().linking_fields()).as_columns(ctx); - // WHERE Parent.id = Child.id - let conditions = join_columns - .zip(related_join_columns) - .fold(None::, |acc, (a, b)| { - let a = a.table(parent_alias.to_table_string()); - let b = b.table(child_alias.to_table_string()); - let condition = a.equals(b); - - match acc { - Some(acc) => Some(acc.and(condition)), - None => Some(condition.into()), - } - }) - .unwrap(); + let conditions = build_join_conditions((join_columns, parent_alias), (related_join_columns, child_alias)); + // WHERE Parent.id = Child.id self.and_where(conditions) } @@ -311,6 +396,13 @@ impl<'a> SelectBuilderExt<'a> for Select<'a> { } _ => acc, }) + .with_virtuals_from_selection(selected_fields) + } + + fn with_virtuals_from_selection(self, selected_fields: &FieldSelection) -> Select<'a> { + build_virtual_selection(selected_fields.virtuals()) + .into_iter() + .fold(self, |select, (alias, expr)| select.value(expr.alias(alias))) } fn with_columns(self, columns: ColumnIterator) -> Select<'a> { @@ -318,6 +410,25 @@ impl<'a> SelectBuilderExt<'a> for Select<'a> { } } +fn build_join_conditions( + (left_columns, left_alias): (ColumnIterator, Alias), + (right_columns, right_alias): (ColumnIterator, Alias), +) -> ConditionTree<'static> { + left_columns + .zip(right_columns) + .fold(None::, |acc, (a, b)| { + let a = a.table(left_alias.to_table_string()); + let b = b.table(right_alias.to_table_string()); + let condition = a.equals(b); + + match acc { + Some(acc) => Some(acc.and(condition)), + None => Some(condition.into()), + } + }) + .unwrap() +} + fn build_json_obj_fn(rs: &RelationSelection, ctx: &Context<'_>, root_alias: Alias) -> Function<'static> { let build_obj_params = rs .selections @@ -340,6 +451,7 @@ fn build_json_obj_fn(rs: &RelationSelection, ctx: &Context<'_>, root_alias: Alia } _ => None, }) + .chain(build_virtual_selection(rs.virtuals())) .collect(); json_build_object(build_obj_params) @@ -410,3 +522,36 @@ fn json_agg() -> Function<'static> { ]) .alias(JSON_AGG_IDENT) } + +fn build_virtual_selection<'a>( + virtual_fields: impl Iterator, +) -> Vec<(Cow<'static, str>, Expression<'static>)> { + let mut selected_objects = BTreeMap::new(); + + for vs in virtual_fields { + match vs { + VirtualSelection::RelationCount(rf, _) => { + let (object_name, field_name) = vs.serialized_name(); + + let coalesce_args: Vec> = vec![ + Column::from((relation_count_alias_name(rf), vs.db_alias())).into(), + 0.raw().into(), + ]; + + selected_objects + .entry(object_name) + .or_insert(Vec::new()) + .push((field_name.to_owned().into(), coalesce(coalesce_args).into())); + } + } + } + + selected_objects + .into_iter() + .map(|(name, fields)| (name.into(), json_build_object(fields).into())) + .collect() +} + +fn relation_count_alias_name(rf: &RelationField) -> String { + format!("aggr_count_{}_{}", rf.model().name(), rf.name()) +} diff --git a/query-engine/core/src/interpreter/query_interpreters/read.rs b/query-engine/core/src/interpreter/query_interpreters/read.rs index 883246dc84e2..89747d33dbe3 100644 --- a/query-engine/core/src/interpreter/query_interpreters/read.rs +++ b/query-engine/core/src/interpreter/query_interpreters/read.rs @@ -64,6 +64,7 @@ fn read_one( name: query.name, model, fields: query.selection_order, + virtuals: query.selected_fields.virtuals_owned(), records, nested: build_relation_record_selection(query.selected_fields.relations()), } @@ -172,6 +173,7 @@ fn read_many_by_joins( Ok(RecordSelectionWithRelations { name: query.name, fields: query.selection_order, + virtuals: query.selected_fields.virtuals_owned(), records: result, nested: build_relation_record_selection(query.selected_fields.relations()), model: query.model, @@ -190,6 +192,7 @@ fn build_relation_record_selection<'a>( .map(|rq| RelationRecordSelection { name: rq.field.name().to_owned(), fields: rq.result_fields.clone(), + virtuals: rq.virtuals().cloned().collect(), model: rq.field.related_model(), nested: build_relation_record_selection(rq.relations()), }) diff --git a/query-engine/core/src/query_ast/read.rs b/query-engine/core/src/query_ast/read.rs index 64a2440c0f52..28f2f8383649 100644 --- a/query-engine/core/src/query_ast/read.rs +++ b/query-engine/core/src/query_ast/read.rs @@ -73,19 +73,6 @@ impl ReadQuery { ReadQuery::AggregateRecordsQuery(_) => false, } } - - pub(crate) fn has_virtual_selections(&self) -> bool { - fn has_virtuals(selection: &FieldSelection, nested: &[ReadQuery]) -> bool { - selection.has_virtual_fields() || nested.iter().any(|q| q.has_virtual_selections()) - } - - match self { - ReadQuery::RecordQuery(q) => has_virtuals(&q.selected_fields, &q.nested), - ReadQuery::ManyRecordsQuery(q) => has_virtuals(&q.selected_fields, &q.nested), - ReadQuery::RelatedRecordsQuery(q) => has_virtuals(&q.selected_fields, &q.nested), - ReadQuery::AggregateRecordsQuery(_) => false, - } - } } impl FilteredQuery for ReadQuery { @@ -243,10 +230,6 @@ impl RelatedRecordsQuery { pub fn has_distinct(&self) -> bool { self.args.distinct.is_some() || self.nested.iter().any(|q| q.has_distinct()) } - - pub fn has_virtual_selections(&self) -> bool { - self.selected_fields.has_virtual_fields() || self.nested.iter().any(|q| q.has_virtual_selections()) - } } #[derive(Debug, Clone)] diff --git a/query-engine/core/src/query_graph_builder/read/many.rs b/query-engine/core/src/query_graph_builder/read/many.rs index 29eb769f74d0..edadeb8814df 100644 --- a/query-engine/core/src/query_graph_builder/read/many.rs +++ b/query-engine/core/src/query_graph_builder/read/many.rs @@ -42,7 +42,6 @@ fn find_many_with_options( args.cursor.as_ref(), args.distinct.as_ref(), &nested, - &selected_fields, query_schema, ); diff --git a/query-engine/core/src/query_graph_builder/read/one.rs b/query-engine/core/src/query_graph_builder/read/one.rs index afc07ed0e89e..a2dd291f6760 100644 --- a/query-engine/core/src/query_graph_builder/read/one.rs +++ b/query-engine/core/src/query_graph_builder/read/one.rs @@ -50,14 +50,8 @@ fn find_unique_with_options( let nested = utils::collect_nested_queries(nested_fields, &model, query_schema)?; let selected_fields = utils::merge_relation_selections(selected_fields, None, &nested); - let relation_load_strategy = get_relation_load_strategy( - requested_rel_load_strategy, - None, - None, - &nested, - &selected_fields, - query_schema, - ); + let relation_load_strategy = + get_relation_load_strategy(requested_rel_load_strategy, None, None, &nested, query_schema); Ok(ReadQuery::RecordQuery(RecordQuery { name, diff --git a/query-engine/core/src/query_graph_builder/read/utils.rs b/query-engine/core/src/query_graph_builder/read/utils.rs index 69fe60d95f39..369cd312d4bd 100644 --- a/query-engine/core/src/query_graph_builder/read/utils.rs +++ b/query-engine/core/src/query_graph_builder/read/utils.rs @@ -257,16 +257,14 @@ pub(crate) fn get_relation_load_strategy( cursor: Option<&SelectionResult>, distinct: Option<&FieldSelection>, nested_queries: &[ReadQuery], - selected_fields: &FieldSelection, query_schema: &QuerySchema, ) -> RelationLoadStrategy { if query_schema.has_feature(PreviewFeature::RelationJoins) && query_schema.has_capability(ConnectorCapability::LateralJoin) && cursor.is_none() && distinct.is_none() - && !selected_fields.has_virtual_fields() && !nested_queries.iter().any(|q| match q { - ReadQuery::RelatedRecordsQuery(q) => q.has_cursor() || q.has_distinct() || q.has_virtual_selections(), + ReadQuery::RelatedRecordsQuery(q) => q.has_cursor() || q.has_distinct(), _ => false, }) && requested_strategy != Some(RelationLoadStrategy::Query) diff --git a/query-engine/core/src/response_ir/internal.rs b/query-engine/core/src/response_ir/internal.rs index 121525dad15c..c6cf4fdd74b6 100644 --- a/query-engine/core/src/response_ir/internal.rs +++ b/query-engine/core/src/response_ir/internal.rs @@ -7,7 +7,7 @@ use crate::{ }; use connector::AggregationResult; use indexmap::IndexMap; -use query_structure::{CompositeFieldRef, Field, PrismaValue, SelectionResult, VirtualSelection}; +use query_structure::{CompositeFieldRef, Field, Model, PrismaValue, SelectionResult, VirtualSelection}; use schema::{ constants::{aggregations::*, output_fields::*}, *, @@ -307,6 +307,27 @@ fn finalize_objects( } } +enum SerializedFieldWithRelations<'a, 'b> { + Model(Field, &'a OutputField<'b>), + VirtualsGroup(&'a str, Vec<&'a VirtualSelection>), +} + +impl<'a, 'b> SerializedFieldWithRelations<'a, 'b> { + fn name(&self) -> &str { + match self { + Self::Model(f, _) => f.name(), + Self::VirtualsGroup(name, _) => name, + } + } + + fn db_name(&self) -> &str { + match self { + Self::Model(f, _) => f.db_name(), + Self::VirtualsGroup(name, _) => name, + } + } +} + // TODO: Handle errors properly fn serialize_objects_with_relation( result: RecordSelectionWithRelations, @@ -314,14 +335,10 @@ fn serialize_objects_with_relation( ) -> crate::Result { let mut object_mapping = UncheckedItemsWithParents::with_capacity(result.records.records.len()); - let model = result.model; - let db_field_names = result.records.field_names; let nested = result.nested; - let fields: Vec<_> = db_field_names - .iter() - .filter_map(|f| model.fields().all().find(|field| field.db_name() == f)) - .collect(); + let fields = + collect_serialized_fields_with_relations(typ, &result.model, &result.virtuals, &result.records.field_names); // Hack: we convert it to a hashset to support contains with &str as input // because Vec::contains(&str) doesn't work and we don't want to allocate a string record value @@ -341,13 +358,16 @@ fn serialize_objects_with_relation( continue; } - let out_field = typ.find_field(field.name()).unwrap(); - match field { - Field::Scalar(_) if !out_field.field_type().is_object() => { + SerializedFieldWithRelations::Model(Field::Scalar(_), out_field) + if !out_field.field_type().is_object() => + { object.insert(field.name().to_owned(), serialize_scalar(out_field, val)?); } - Field::Relation(_) if out_field.field_type().is_list() => { + + SerializedFieldWithRelations::Model(Field::Relation(_), out_field) + if out_field.field_type().is_list() => + { let inner_typ = out_field.field_type.as_object_type().unwrap(); let rrs = nested.iter().find(|rrs| rrs.name == field.name()).unwrap(); @@ -360,7 +380,8 @@ fn serialize_objects_with_relation( object.insert(field.name().to_owned(), Item::list(items)); } - Field::Relation(_) => { + + SerializedFieldWithRelations::Model(Field::Relation(_), out_field) => { let inner_typ = out_field.field_type.as_object_type().unwrap(); let rrs = nested.iter().find(|rrs| rrs.name == field.name()).unwrap(); @@ -369,6 +390,11 @@ fn serialize_objects_with_relation( serialize_relation_selection(rrs, val, inner_typ)?, ); } + + SerializedFieldWithRelations::VirtualsGroup(group_name, virtuals) => { + object.insert(group_name.to_string(), serialize_virtuals_group(val, virtuals)?); + } + _ => panic!("unexpected field"), } } @@ -397,21 +423,18 @@ fn serialize_relation_selection( // TODO: better handle errors let mut value_obj: HashMap = HashMap::from_iter(value.into_object().unwrap()); - let db_field_names = &rrs.fields; - let fields: Vec<_> = db_field_names - .iter() - .filter_map(|f| rrs.model.fields().all().find(|field| field.name() == f)) - .collect(); + + let fields = collect_serialized_fields_with_relations(typ, &rrs.model, &rrs.virtuals, &rrs.fields); for field in fields { - let out_field = typ.find_field(field.name()).unwrap(); let value = value_obj.remove(field.db_name()).unwrap(); match field { - Field::Scalar(_) if !out_field.field_type().is_object() => { + SerializedFieldWithRelations::Model(Field::Scalar(_), out_field) if !out_field.field_type().is_object() => { map.insert(field.name().to_owned(), serialize_scalar(out_field, value)?); } - Field::Relation(_) if out_field.field_type().is_list() => { + + SerializedFieldWithRelations::Model(Field::Relation(_), out_field) if out_field.field_type().is_list() => { let inner_typ = out_field.field_type.as_object_type().unwrap(); let inner_rrs = rrs.nested.iter().find(|rrs| rrs.name == field.name()).unwrap(); @@ -424,7 +447,8 @@ fn serialize_relation_selection( map.insert(field.name().to_owned(), Item::list(items)); } - Field::Relation(_) => { + + SerializedFieldWithRelations::Model(Field::Relation(_), out_field) => { let inner_typ = out_field.field_type.as_object_type().unwrap(); let inner_rrs = rrs.nested.iter().find(|rrs| rrs.name == field.name()).unwrap(); @@ -433,6 +457,11 @@ fn serialize_relation_selection( serialize_relation_selection(inner_rrs, value, inner_typ)?, ); } + + SerializedFieldWithRelations::VirtualsGroup(group_name, virtuals) => { + map.insert(group_name.to_string(), serialize_virtuals_group(value, &virtuals)?); + } + _ => (), } } @@ -440,6 +469,53 @@ fn serialize_relation_selection( Ok(Item::Map(map)) } +fn collect_serialized_fields_with_relations<'a, 'b>( + object_type: &'a ObjectType<'b>, + model: &Model, + virtuals: &'a [VirtualSelection], + db_field_names: &'a [String], +) -> Vec> { + db_field_names + .iter() + .map(|name| { + model + .fields() + .all() + .find(|field| field.db_name() == name) + .and_then(|field| { + object_type + .find_field(field.name()) + .map(|out_field| SerializedFieldWithRelations::Model(field, out_field)) + }) + .unwrap_or_else(|| { + let matching_virtuals = virtuals.iter().filter(|vs| vs.serialized_name().0 == name).collect(); + SerializedFieldWithRelations::VirtualsGroup(name.as_str(), matching_virtuals) + }) + }) + .collect() +} + +fn serialize_virtuals_group(obj_value: PrismaValue, virtuals: &[&VirtualSelection]) -> crate::Result { + let mut db_object: HashMap = HashMap::from_iter(obj_value.into_object().unwrap()); + let mut out_object = Map::new(); + + // We have to reorder the object fields according to selection even if the query + // builder respects the initial order because JSONB does not preserve order. + for vs in virtuals { + let (group_name, nested_name) = vs.serialized_name(); + + let value = db_object.remove(nested_name).ok_or_else(|| { + CoreError::SerializationError(format!( + "Expected virtual field {nested_name} not found in {group_name} object" + )) + })?; + + out_object.insert(nested_name.into(), Item::Value(vs.coerce_value(value)?)); + } + + Ok(Item::Map(out_object)) +} + enum SerializedField<'a, 'b> { Model(Field, &'a OutputField<'b>), Virtual(&'a VirtualSelection), diff --git a/query-engine/core/src/result_ast/mod.rs b/query-engine/core/src/result_ast/mod.rs index e86c39ddf392..e450b7213774 100644 --- a/query-engine/core/src/result_ast/mod.rs +++ b/query-engine/core/src/result_ast/mod.rs @@ -20,6 +20,11 @@ pub struct RecordSelectionWithRelations { /// Holds an ordered list of selected field names for each contained record. pub(crate) fields: Vec, + /// Holds the list of virtual selections included in the query result. + /// TODO: in the future it should be covered by [`RecordSelection::fields`] by storing ordered + /// `Vec` or `FieldSelection` instead of `Vec`. + pub(crate) virtuals: Vec, + /// Selection results pub(crate) records: ManyRecords, @@ -41,6 +46,8 @@ pub struct RelationRecordSelection { pub name: String, /// Holds an ordered list of selected field names for each contained record. pub fields: Vec, + /// Holds the list of virtual selections included in the query result. + pub virtuals: Vec, /// The model of the contained records. pub model: Model, /// Nested relation selections diff --git a/query-engine/query-structure/src/field_selection.rs b/query-engine/query-structure/src/field_selection.rs index 20b037f4e571..f2b1fccd9c5b 100644 --- a/query-engine/query-structure/src/field_selection.rs +++ b/query-engine/query-structure/src/field_selection.rs @@ -44,10 +44,7 @@ impl FieldSelection { } pub fn virtuals(&self) -> impl Iterator { - self.selections().filter_map(|field| match field { - SelectedField::Virtual(ref vs) => Some(vs), - _ => None, - }) + self.selections().filter_map(SelectedField::as_virtual) } pub fn virtuals_owned(&self) -> Vec { @@ -71,16 +68,40 @@ impl FieldSelection { FieldSelection::new(non_virtuals.into_iter().chain(virtuals).collect()) } + /// Returns the selections, grouping the virtual fields that are wrapped into objects in the + /// query (like `_count`) and returning only the first virtual field in each of those groups. + /// This is useful when we want to treat the group as a whole but we don't need the information + /// about every field in the group and can infer the necessary information (like the group + /// name) from any of those fields. This method is used by + /// [`FieldSelection::db_names_grouping_virtuals`] and + /// [`FieldSelection::type_identifiers_with_arities_grouping_virtuals`]. + fn selections_with_virtual_group_heads(&self) -> impl Iterator { + self.selections().unique_by(|f| f.db_name_grouping_virtuals()) + } + /// Returns all Prisma (e.g. schema model field) names of contained fields. /// Does _not_ recurse into composite selections and only iterates top level fields. pub fn prisma_names(&self) -> impl Iterator + '_ { - self.selections.iter().map(|f| f.prisma_name().into_owned()) + self.selections().map(|f| f.prisma_name().into_owned()) } /// Returns all database (e.g. column or document field) names of contained fields. - /// Does _not_ recurse into composite selections and only iterates level fields. + /// Does _not_ recurse into composite selections and only iterates top level fields. + /// Returns db aliases for virtual fields grouped into objects in the query separately, + /// representing results of queries that do not load relations using JOINs. pub fn db_names(&self) -> impl Iterator + '_ { - self.selections.iter().map(|f| f.db_name().into_owned()) + self.selections().map(|f| f.db_name().into_owned()) + } + + /// Returns all database (e.g. column or document field) names of contained fields. Does not + /// recurse into composite selections and only iterates top level fields. Also does not recurse + /// into the grouped containers for virtual fields, like `_count`. The names returned by this + /// method correspond to the results of queries that use JSON objects to represent joined + /// relations and relation aggregations. + pub fn db_names_grouping_virtuals(&self) -> impl Iterator + '_ { + self.selections_with_virtual_group_heads() + .map(|f| f.db_name_grouping_virtuals()) + .map(Cow::into_owned) } /// Checked if a field of prisma name `name` is present in this `FieldSelection`. @@ -182,15 +203,20 @@ impl FieldSelection { *self = this.merge(other); } + /// Returns type identifiers and arities, treating all virtual fields as separate fields. pub fn type_identifiers_with_arities(&self) -> Vec<(TypeIdentifier, FieldArity)> { self.selections() - .filter_map(|selection| match selection { - SelectedField::Scalar(sf) => Some(sf.type_identifier_with_arity()), - SelectedField::Relation(rf) if rf.field.is_list() => Some((TypeIdentifier::Json, FieldArity::Required)), - SelectedField::Relation(rf) => Some((TypeIdentifier::Json, rf.field.arity())), - SelectedField::Composite(_) => None, - SelectedField::Virtual(vs) => Some(vs.type_identifier_with_arity()), - }) + .filter_map(SelectedField::type_identifier_with_arity) + .collect() + } + + /// Returns type identifiers and arities, grouping the virtual fields so that the type + /// identifier and arity is returned for the whole object containing multiple virtual fields + /// and not each of those fields separately. This represents the selection in joined queries + /// that use JSON objects for relations and relation aggregations. + pub fn type_identifiers_with_arities_grouping_virtuals(&self) -> Vec<(TypeIdentifier, FieldArity)> { + self.selections_with_virtual_group_heads() + .filter_map(|vs| vs.type_identifier_with_arity_grouping_virtuals()) .collect() } @@ -245,6 +271,10 @@ impl RelationSelection { }) } + pub fn virtuals(&self) -> impl Iterator { + self.selections.iter().filter_map(SelectedField::as_virtual) + } + pub fn related_model(&self) -> Model { self.field.related_model() } @@ -314,6 +344,9 @@ impl SelectedField { } } + /// Returns the name of the field in the database (if applicable) or other kind of name that is + /// used in the queries for this field. For virtual fields, this returns the alias used in the + /// queries that do not group them into objects. pub fn db_name(&self) -> Cow<'_, str> { match self { SelectedField::Scalar(sf) => sf.db_name().into(), @@ -323,6 +356,49 @@ impl SelectedField { } } + /// Returns the name of the field in the database (if applicable) or other kind of name that is + /// used in the queries for this field. For virtual fields that are wrapped inside an object in + /// Prisma queries, this returns the name of the surrounding object and not the field itself, + /// so this method can return identical values for multiple fields in the [`FieldSelection`]. + /// This is used in queries with relation JOINs which use JSON objects to represent both + /// relations and relation aggregations. For those queries, the result of this method + /// corresponds to the top-level name of the value which is a JSON object that contains this + /// field inside. + pub fn db_name_grouping_virtuals(&self) -> Cow<'_, str> { + match self { + SelectedField::Virtual(vs) => vs.serialized_name().0.into(), + _ => self.db_name(), + } + } + + /// Returns the type identifier and arity of this field, unless it is a composite field, in + /// which case [`None`] is returned. + pub fn type_identifier_with_arity(&self) -> Option<(TypeIdentifier, FieldArity)> { + match self { + SelectedField::Scalar(sf) => Some(sf.type_identifier_with_arity()), + SelectedField::Relation(rf) if rf.field.is_list() => Some((TypeIdentifier::Json, FieldArity::Required)), + SelectedField::Relation(rf) => Some((TypeIdentifier::Json, rf.field.arity())), + SelectedField::Composite(_) => None, + SelectedField::Virtual(vs) => Some(vs.type_identifier_with_arity()), + } + } + + /// Returns the type identifier and arity of this field, unless it is a composite field, in + /// which case [`None`] is returned. + /// + /// In the case of virtual fields that are wrapped into objects in Prisma queries + /// (specifically, relation aggregations), the returned information refers not to the current + /// field itself but to the whole object that contains this field. This is used by the queries + /// with relation JOINs because they use JSON objects to reprsent both relations and relation + /// aggregations, so individual virtual fields that correspond to those relation aggregations + /// don't exist as separate values in the result of the query. + pub fn type_identifier_with_arity_grouping_virtuals(&self) -> Option<(TypeIdentifier, FieldArity)> { + match self { + SelectedField::Virtual(_) => Some((TypeIdentifier::Json, FieldArity::Required)), + _ => self.type_identifier_with_arity(), + } + } + pub fn as_composite(&self) -> Option<&CompositeSelection> { match self { SelectedField::Composite(ref cs) => Some(cs), @@ -330,6 +406,13 @@ impl SelectedField { } } + pub fn as_virtual(&self) -> Option<&VirtualSelection> { + match self { + SelectedField::Virtual(vs) => Some(vs), + _ => None, + } + } + pub fn container(&self) -> ParentContainer { match self { SelectedField::Scalar(sf) => sf.container(),