From bf6d2efa330f2580f1557076ba7eeaafbacb5a0e Mon Sep 17 00:00:00 2001
From: Scott Donnelly
Date: Mon, 9 Sep 2024 08:34:29 +0100
Subject: [PATCH] feat: RecordBatchEvolutionProcessor handles skipped fields in projection

---
 crates/iceberg/src/arrow/reader.rs            |  19 +-
 .../arrow/record_batch_evolution_processor.rs | 426 ++++++++++--------
 crates/iceberg/src/scan.rs                    |  27 ++
 3 files changed, 274 insertions(+), 198 deletions(-)

diff --git a/crates/iceberg/src/arrow/reader.rs b/crates/iceberg/src/arrow/reader.rs
index 457cb09d2..35d51c935 100644
--- a/crates/iceberg/src/arrow/reader.rs
+++ b/crates/iceberg/src/arrow/reader.rs
@@ -197,11 +197,8 @@ impl ArrowReader {
-        // create a RecordBatchEvolutionProcessor if our task schema contains columns
-        // not present in the parquet file or whose types have been promoted
-        let record_batch_evolution_processor = RecordBatchEvolutionProcessor::build(
-            record_batch_stream_builder.schema(),
-            task.schema(),
-            task.project_field_ids(),
-        )?;
+        // create a RecordBatchEvolutionProcessor to handle any column projection,
+        // type promotion, or default-value insertion that the task schema requires
+        let mut record_batch_evolution_processor =
+            RecordBatchEvolutionProcessor::build(task.schema_ref(), task.project_field_ids());
 
         if let Some(batch_size) = batch_size {
             record_batch_stream_builder = record_batch_stream_builder.with_batch_size(batch_size);
@@ -243,15 +240,9 @@
         // to the requester.
         let mut record_batch_stream = record_batch_stream_builder.build()?;
 
-        if let Some(record_batch_evolution_processor) = record_batch_evolution_processor {
-            while let Some(batch) = record_batch_stream.try_next().await? {
-                tx.send(record_batch_evolution_processor.process_record_batch(batch))
-                    .await?
-            }
-        } else {
-            while let Some(batch) = record_batch_stream.try_next().await? {
-                tx.send(Ok(batch)).await?
-            }
+        while let Some(batch) = record_batch_stream.try_next().await? {
+            tx.send(record_batch_evolution_processor.process_record_batch(batch))
+                .await?
         }
 
         Ok(())
diff --git a/crates/iceberg/src/arrow/record_batch_evolution_processor.rs b/crates/iceberg/src/arrow/record_batch_evolution_processor.rs
index 9b3d20098..9e4106c14 100644
--- a/crates/iceberg/src/arrow/record_batch_evolution_processor.rs
+++ b/crates/iceberg/src/arrow/record_batch_evolution_processor.rs
@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 use std::sync::Arc;
 
 use arrow::compute::cast;
@@ -5,7 +6,9 @@ use arrow_array::{
     Array as ArrowArray, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
     Int32Array, Int64Array, NullArray, RecordBatch, StringArray,
 };
-use arrow_schema::{DataType, FieldRef, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef};
+use arrow_schema::{
+    DataType, FieldRef, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, SchemaRef,
+};
 use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
 
 use crate::arrow::schema_to_arrow_schema;
@@ -16,19 +19,19 @@ use crate::{Error, ErrorKind, Result};
 /// to transform a RecordBatch coming from a Parquet file record
 /// batch stream so that it conforms to an Iceberg schema that has
 /// evolved from the one that was used when the file was written.
-#[derive(Debug)]
-pub(crate) struct EvolutionOp {
-    index: usize,
-    action: EvolutionAction,
-}
-
 #[derive(Debug)]
 pub(crate) enum EvolutionAction {
-    // signifies that a particular column has undergone type promotion,
-    // thus the column with the given index needs to be promoted to the
+    // signifies that a column should be passed through unmodified
+    PassThrough {
+        source_index: usize,
+    },
+
+    // signifies that a particular column has undergone type promotion, and so
+    // the source column with the given index needs to be promoted to the
     // specified type
     Promote {
         target_type: DataType,
+        source_index: usize,
     },
 
     // Signifies that a new column has been inserted before the row
@@ -55,14 +58,22 @@ pub(crate) enum EvolutionAction {
 }
 
 #[derive(Debug)]
-pub(crate) struct RecordBatchEvolutionProcessor {
-    operations: Vec<EvolutionOp>,
-
+struct SchemaAndOps {
     // Every transformed RecordBatch will have the same schema. We create the
    // target just once and cache it here. Helpfully, Arc<ArrowSchema> is needed in
     // the constructor for RecordBatch, so we don't need an expensive copy
     // each time.
-    target_schema: Arc<ArrowSchema>,
+    pub target_schema: Arc<ArrowSchema>,
+
+    // Indicates how each column in the target schema is derived.
+    pub operations: Vec<EvolutionAction>,
+}
+
+#[derive(Debug)]
+pub(crate) struct RecordBatchEvolutionProcessor {
+    snapshot_schema: Arc<IcebergSchema>,
+    projected_iceberg_field_ids: Vec<i32>,
+    schema_and_ops: Option<SchemaAndOps>,
 }
 
@@ -70,30 +81,68 @@ impl RecordBatchEvolutionProcessor {
-    /// Fallibly try to build a RecordBatchEvolutionProcessor for a given parquet file schema
-    /// and Iceberg snapshot schema. Returns Ok(None) if the processor would not be required
-    /// due to the file schema already matching the snapshot schema
+    /// Build a RecordBatchEvolutionProcessor for the given Iceberg snapshot schema
+    /// and projected field ids. The target schema and per-column operations are
+    /// derived lazily, on the first call to `process_record_batch`, once the
+    /// parquet file's schema is known.
     pub(crate) fn build(
-        source_schema: &ArrowSchemaRef,
-        snapshot_schema: &IcebergSchema,
+        // source_schema: &ArrowSchemaRef,
+        snapshot_schema: Arc<IcebergSchema>,
         projected_iceberg_field_ids: &[i32],
-    ) -> Result<Option<Self>> {
-        let operations: Vec<_> =
-            Self::generate_operations(source_schema, snapshot_schema, projected_iceberg_field_ids)?;
-
-        Ok(if operations.is_empty() {
-            None
+    ) -> Self {
+        let projected_iceberg_field_ids = if projected_iceberg_field_ids.is_empty() {
+            // project all fields in table schema order
+            snapshot_schema
+                .as_struct()
+                .fields()
+                .iter()
+                .map(|field| field.id)
+                .collect()
         } else {
-            Some(Self {
-                operations,
-                target_schema: Arc::new(schema_to_arrow_schema(snapshot_schema)?),
-            })
-        })
-    }
+            projected_iceberg_field_ids.to_vec()
+        };
 
-    fn target_schema(&self) -> Arc<ArrowSchema> {
-        self.target_schema.clone()
+        Self {
+            snapshot_schema,
+            projected_iceberg_field_ids,
+            schema_and_ops: None,
+        }
+
+        // let (operations, target_schema) = Self::generate_operations_and_schema(
+        //     source_schema,
+        //     snapshot_schema,
+        //     projected_iceberg_field_ids,
+        // )?;
+        //
+        // Ok(if target_schema.as_ref() == source_schema.as_ref() {
+        //     None
+        // } else {
+        //     Some(Self {
+        //         operations,
+        //         target_schema,
+        //     })
+        // })
     }
 
-    pub(crate) fn process_record_batch(&self, record_batch: RecordBatch) -> Result<RecordBatch> {
+    pub(crate) fn process_record_batch(
+        &mut self,
+        record_batch: RecordBatch,
+    ) -> Result<RecordBatch> {
+        if self.schema_and_ops.is_none() {
+            self.schema_and_ops = Some(Self::generate_operations_and_schema(
+                record_batch.schema_ref(),
+                self.snapshot_schema.as_ref(),
+                &self.projected_iceberg_field_ids,
+            )?);
+        }
+
+        let Some(SchemaAndOps {
+            ref target_schema, ..
+ }) = self.schema_and_ops + else { + return Err(Error::new( + ErrorKind::Unexpected, + "schema_and_ops always created at this point", + )); + }; + Ok(RecordBatch::try_new( - self.target_schema(), + target_schema.clone(), self.transform_columns(record_batch.columns())?, )?) } @@ -101,100 +150,91 @@ impl RecordBatchEvolutionProcessor { // create the (possibly empty) list of `EvolutionOp`s that we need // to apply to the arrays in a record batch with `source_schema` so // that it matches the `snapshot_schema` - fn generate_operations( + fn generate_operations_and_schema( source_schema: &ArrowSchemaRef, snapshot_schema: &IcebergSchema, projected_iceberg_field_ids: &[i32], - ) -> Result> { - let mut ops = vec![]; - - let mapped_unprojected_arrow_schema = schema_to_arrow_schema(snapshot_schema)?; - // need to create a new arrow schema here by selecting fields from mapped_unprojected, + ) -> Result { + let mapped_unprojected_arrow_schema = Arc::new(schema_to_arrow_schema(snapshot_schema)?); + let field_id_to_source_schema_map = + Self::build_field_id_to_arrow_schema_map(source_schema)?; + let field_id_to_mapped_schema_map = + Self::build_field_id_to_arrow_schema_map(&mapped_unprojected_arrow_schema)?; + + // Create a new arrow schema by selecting fields from mapped_unprojected, // in the order of the field ids in projected_iceberg_field_ids - - // right now the below is incorrect if projected_iceberg_field_ids skips any iceberg fields - // or re-orders any - - for &projected_field_id in projected_iceberg_field_ids { - let iceberg_field = snapshot_schema.field_by_id(projected_field_id).ok_or_else(|| { - Error::new( - ErrorKind::Unexpected, - "projected field id not found in snapshot schema", - ) - })?; - let (mapped_arrow_field, _) = Self::get_arrow_field_with_field_id(&mapped_arrow_schema, projected_field_id)?; - let (orig_arrow_field, orig_arrow_field_idx) = Self::get_arrow_field_with_field_id(&source_schema, projected_field_id)?; - - let (arrow_field, add_op_required) = - if source_schema_idx < source_schema.fields().len() { - let orig_arrow_field = source_schema.field(source_schema_idx); - let arrow_field_id: i32 = orig_arrow_field - .metadata() - .get(PARQUET_FIELD_ID_META_KEY) - .ok_or_else(|| { - Error::new( - ErrorKind::DataInvalid, - "field ID not present in parquet metadata", - ) - })? - .parse() - .map_err(|e| { - Error::new( - ErrorKind::DataInvalid, - format!("field id not parseable as an i32: {}", e), - ) - })?; - (Some(orig_arrow_field), arrow_field_id != projected_field_id) + let fields: Result> = projected_iceberg_field_ids + .iter() + .map(|field_id| { + Ok(field_id_to_mapped_schema_map + .get(field_id) + .ok_or(Error::new(ErrorKind::Unexpected, "field not found"))? 
+ .0 + .clone()) + }) + .collect(); + let target_schema = ArrowSchema::new(fields?); + + let operations: Result> = projected_iceberg_field_ids.iter().map(|field_id|{ + let (target_field, _) = field_id_to_mapped_schema_map.get(field_id).ok_or( + Error::new(ErrorKind::Unexpected, "could not find field in schema") + )?; + let target_type = target_field.data_type(); + + Ok(if let Some((source_field, source_index)) = field_id_to_source_schema_map.get(field_id) { + // column present in source + + if source_field.data_type().equals_datatype(target_type) { + // no promotion required + EvolutionAction::PassThrough { + source_index: *source_index + } } else { - (None, true) - }; + // promotion required + EvolutionAction::Promote { + target_type: target_type.clone(), + source_index: *source_index, + } + } + } else { + // column must be added + let iceberg_field = snapshot_schema.field_by_id(*field_id).ok_or( + Error::new(ErrorKind::Unexpected, "Field not found in snapshot schema") + )?; - if add_op_required { let default_value = if let Some(ref iceberg_default_value) = &iceberg_field.initial_default { let Literal::Primitive(prim_value) = iceberg_default_value else { return Err(Error::new( - ErrorKind::Unexpected, - format!("Default value for column must be primitive type, but encountered {:?}", iceberg_default_value) - )); + ErrorKind::Unexpected, + format!("Default value for column must be primitive type, but encountered {:?}", iceberg_default_value) + )); }; Some(prim_value.clone()) } else { None }; - ops.push(EvolutionOp { - index: source_schema_idx, - action: EvolutionAction::Add { - value: default_value, - target_type: mapped_arrow_field.data_type().clone(), - }, - }) - } else { - if !arrow_field - .unwrap() // will never fail as we only get here if we have Some(field) - .data_type() - .equals_datatype(mapped_arrow_field.data_type()) - { - ops.push(EvolutionOp { - index: source_schema_idx, - action: EvolutionAction::Promote { - target_type: mapped_arrow_field.data_type().clone(), - }, - }) + EvolutionAction::Add { + value: default_value, + target_type: target_type.clone(), } + }) + }).collect(); - source_schema_idx += 1; - } - } - - Ok(ops) + Ok(SchemaAndOps { + operations: operations?, + target_schema: Arc::new(target_schema), + }) } - fn get_arrow_field_with_field_id(arrow_schema: &ArrowSchema, field_id: i32) -> Result<(FieldRef, usize)> { - for (field, idx) in arrow_schema.fields().enumerate().iter() { - let this_field_id: i32 = field + fn build_field_id_to_arrow_schema_map( + source_schema: &SchemaRef, + ) -> Result> { + let mut field_id_to_source_schema = HashMap::new(); + for (source_field_idx, source_field) in source_schema.fields.iter().enumerate() { + let this_field_id = source_field .metadata() .get(PARQUET_FIELD_ID_META_KEY) .ok_or_else(|| { @@ -211,49 +251,47 @@ impl RecordBatchEvolutionProcessor { ) })?; - if this_field_id == field_id { - return Ok((field.clone(), idx)) - } + field_id_to_source_schema + .insert(this_field_id, (source_field.clone(), source_field_idx)); } - Err(Error::new( - ErrorKind::Unexpected, - format!("field with id {} not found in parquet schema", field_id) - )) + Ok(field_id_to_source_schema) } fn transform_columns( &self, columns: &[Arc], ) -> Result>> { - let mut result = Vec::with_capacity(columns.len() + self.operations.len()); - let num_rows = if columns.is_empty() { - 0 - } else { - columns[0].len() + if columns.is_empty() { + return Ok(columns.to_vec()); + } + let num_rows = columns[0].len(); + + let Some(ref schema_and_ops) = self.schema_and_ops else { 
+            return Err(Error::new(
+                ErrorKind::Unexpected,
+                "schema_and_ops was None, but should be present",
+            ));
+        };
 
-        let mut col_idx = 0;
-        let mut op_idx = 0;
-        while op_idx < self.operations.len() || col_idx < columns.len() {
-            if op_idx < self.operations.len() && self.operations[op_idx].index == col_idx {
-                match &self.operations[op_idx].action {
+        let result: Result<Vec<_>> = schema_and_ops
+            .operations
+            .iter()
+            .map(|op| {
+                Ok(match op {
+                    EvolutionAction::PassThrough { source_index } => columns[*source_index].clone(),
+                    EvolutionAction::Promote {
+                        target_type,
+                        source_index,
+                    } => cast(&*columns[*source_index], target_type)?,
                     EvolutionAction::Add { target_type, value } => {
-                        result.push(Self::create_column(target_type, value, num_rows)?);
+                        Self::create_column(target_type, value, num_rows)?
                     }
-                    EvolutionAction::Promote { target_type } => {
-                        result.push(cast(&*columns[col_idx], target_type)?);
-                        col_idx += 1;
-                    }
-                }
-                op_idx += 1;
-            } else {
-                result.push(columns[col_idx].clone());
-                col_idx += 1;
-            }
-        }
+                })
+            })
+            .collect();
 
-        Ok(result)
+        result
     }
 
     fn create_column(
@@ -337,57 +375,52 @@ mod test {
     use crate::spec::{Literal, NestedField, PrimitiveType, Schema, Type};
 
     #[test]
-    fn build_returns_none_when_no_schema_migration_required() {
-        let snapshot_schema = iceberg_table_schema();
+    fn build_field_id_to_source_schema_map_works() {
         let arrow_schema = arrow_schema_already_same_as_target();
-        let projected_iceberg_field_ids = [10, 11, 12, 13, 14];
 
-        let inst = RecordBatchEvolutionProcessor::build(
-            &arrow_schema,
-            &snapshot_schema,
-            &projected_iceberg_field_ids,
-        )
-        .unwrap();
+        let result =
+            RecordBatchEvolutionProcessor::build_field_id_to_arrow_schema_map(&arrow_schema)
+                .unwrap();
+
+        let expected = HashMap::from_iter([
+            (10, (arrow_schema.fields()[0].clone(), 0)),
+            (11, (arrow_schema.fields()[1].clone(), 1)),
+            (12, (arrow_schema.fields()[2].clone(), 2)),
+            (14, (arrow_schema.fields()[3].clone(), 3)),
+            (15, (arrow_schema.fields()[4].clone(), 4)),
+        ]);
 
-        assert!(inst.is_none());
+        assert!(result.eq(&expected));
     }
 
     #[test]
-    fn processor_returns_correct_arrow_schema_when_schema_migration_required() {
-        let snapshot_schema = iceberg_table_schema();
-        let arrow_schema = arrow_schema_promotion_addition_and_renaming_required();
-        let projected_iceberg_field_ids = [10, 11, 12, 13, 14];
-
-        let inst = RecordBatchEvolutionProcessor::build(
-            &arrow_schema,
-            &snapshot_schema,
-            &projected_iceberg_field_ids,
-        )
-        .unwrap()
-        .unwrap();
+    fn processor_returns_properly_shaped_record_batch_when_no_schema_migration_required() {
+        let snapshot_schema = Arc::new(iceberg_table_schema());
+        let projected_iceberg_field_ids = [13, 14];
+
+        let mut inst =
+            RecordBatchEvolutionProcessor::build(snapshot_schema, &projected_iceberg_field_ids);
 
-        let result = inst.target_schema();
+        let result = inst
+            .process_record_batch(source_record_batch_no_migration_required())
+            .unwrap();
 
-        assert_eq!(result, arrow_schema_already_same_as_target());
+        let expected = source_record_batch_no_migration_required();
+
+        assert_eq!(result, expected);
     }
 
     #[test]
     fn processor_returns_properly_shaped_record_batch_when_schema_migration_required() {
-        let snapshot_schema = iceberg_table_schema();
-        let arrow_schema = arrow_schema_promotion_addition_and_renaming_required();
-        let projected_iceberg_field_ids = [10, 11, 12, 13, 14];
-
-        let inst = RecordBatchEvolutionProcessor::build(
-            &arrow_schema,
-            &snapshot_schema,
-            &projected_iceberg_field_ids,
-        )
-        .unwrap()
-        .unwrap();
+        let snapshot_schema =
+            Arc::new(iceberg_table_schema());
+        let projected_iceberg_field_ids = [10, 11, 12, 14, 15]; // a, b, c, e, f
+
+        let mut inst =
+            RecordBatchEvolutionProcessor::build(snapshot_schema, &projected_iceberg_field_ids);
 
         let result = inst.process_record_batch(source_record_batch()).unwrap();
 
-        let expected = expected_record_batch();
+        let expected = expected_record_batch_migration_required();
 
         assert_eq!(result, expected);
     }
 
@@ -396,43 +429,59 @@ mod test {
     pub fn source_record_batch() -> RecordBatch {
         RecordBatch::try_new(
             arrow_schema_promotion_addition_and_renaming_required(),
             vec![
-                Arc::new(Int32Array::from(vec![Some(1001), Some(1002), Some(1003)])),
+                Arc::new(Int32Array::from(vec![Some(1001), Some(1002), Some(1003)])), // b
                 Arc::new(Float32Array::from(vec![
                     Some(12.125),
                     Some(23.375),
                     Some(34.875),
-                ])),
+                ])), // c
+                Arc::new(Int32Array::from(vec![Some(2001), Some(2002), Some(2003)])), // d
+                Arc::new(StringArray::from(vec![
+                    Some("Apache"),
+                    Some("Iceberg"),
+                    Some("Rocks"),
+                ])), // e
+            ],
+        )
+        .unwrap()
+    }
+
+    pub fn source_record_batch_no_migration_required() -> RecordBatch {
+        RecordBatch::try_new(
+            arrow_schema_no_promotion_addition_or_renaming_required(),
+            vec![
+                Arc::new(Int32Array::from(vec![Some(2001), Some(2002), Some(2003)])), // d
                 Arc::new(StringArray::from(vec![
                     Some("Apache"),
                     Some("Iceberg"),
                     Some("Rocks"),
-                ])),
+                ])), // e
             ],
         )
        .unwrap()
     }
 
-    pub fn expected_record_batch() -> RecordBatch {
+    pub fn expected_record_batch_migration_required() -> RecordBatch {
         RecordBatch::try_new(arrow_schema_already_same_as_target(), vec![
             Arc::new(StringArray::from(Vec::<Option<String>>::from([
                 None, None, None,
-            ]))),
-            Arc::new(Int64Array::from(vec![Some(1001), Some(1002), Some(1003)])),
+            ]))), // a
+            Arc::new(Int64Array::from(vec![Some(1001), Some(1002), Some(1003)])), // b
             Arc::new(Float64Array::from(vec![
                 Some(12.125),
                 Some(23.375),
                 Some(34.875),
-            ])),
+            ])), // c
             Arc::new(StringArray::from(vec![
                 Some("Apache"),
                 Some("Iceberg"),
                 Some("Rocks"),
-            ])),
+            ])), // e (d skipped by projection)
             Arc::new(StringArray::from(vec![
                 Some("(╯°□°)╯"),
                 Some("(╯°□°)╯"),
                 Some("(╯°□°)╯"),
-            ])),
+            ])), // f
         ])
         .unwrap()
     }
 
     fn iceberg_table_schema() -> Schema {
         Schema::builder()
             .with_fields(vec![
             NestedField::optional(10, "a", Type::Primitive(PrimitiveType::String)).into(),
             NestedField::required(11, "b", Type::Primitive(PrimitiveType::Long)).into(),
             NestedField::required(12, "c", Type::Primitive(PrimitiveType::Double)).into(),
-            NestedField::optional(13, "d", Type::Primitive(PrimitiveType::String)).into(),
-            NestedField::required(14, "e", Type::Primitive(PrimitiveType::String))
+            NestedField::required(13, "d", Type::Primitive(PrimitiveType::Int)).into(),
+            NestedField::optional(14, "e", Type::Primitive(PrimitiveType::String)).into(),
+            NestedField::required(15, "f", Type::Primitive(PrimitiveType::String))
                 .with_initial_default(Literal::string("(╯°□°)╯"))
                 .into(),
         ])
 
     fn arrow_schema_already_same_as_target() -> Arc<ArrowSchema> {
         Arc::new(ArrowSchema::new(vec![
         simple_field("a", DataType::Utf8, true, "10"),
         simple_field("b", DataType::Int64, false, "11"),
         simple_field("c", DataType::Float64, false, "12"),
-        simple_field("d", DataType::Utf8, true, "13"),
-        simple_field("e", DataType::Utf8, false, "14"),
+        simple_field("e", DataType::Utf8, true, "14"),
+        simple_field("f", DataType::Utf8, false, "15"),
     ]))
 }
 
     fn arrow_schema_promotion_addition_and_renaming_required() -> Arc<ArrowSchema> {
     Arc::new(ArrowSchema::new(vec![
         simple_field("b", DataType::Int32, false, "11"),
         simple_field("c", DataType::Float32, false, "12"),
-        simple_field("d_old", DataType::Utf8, true, "13"),
+        simple_field("d", DataType::Int32, false, "13"),
+        simple_field("e_old", DataType::Utf8, true, "14"),
     ]))
 }
+
+    fn arrow_schema_no_promotion_addition_or_renaming_required() -> Arc<ArrowSchema> {
+        Arc::new(ArrowSchema::new(vec![
+            simple_field("d", DataType::Int32, false, "13"),
+            simple_field("e", DataType::Utf8, true, "14"),
+        ]))
+    }

diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs
index bc7f10a0e..4ca06085d 100644
--- a/crates/iceberg/src/scan.rs
+++ b/crates/iceberg/src/scan.rs
@@ -885,6 +885,33 @@ pub struct FileScanTask {
     pub predicate: Option<BoundPredicate>,
 }
 
+impl FileScanTask {
+    /// Returns the data file path of this file scan task.
+    pub fn data_file_path(&self) -> &str {
+        &self.data_file_path
+    }
+
+    /// Returns the projected field ids of this file scan task.
+    pub fn project_field_ids(&self) -> &[i32] {
+        &self.project_field_ids
+    }
+
+    /// Returns the predicate of this file scan task.
+    pub fn predicate(&self) -> Option<&BoundPredicate> {
+        self.predicate.as_ref()
+    }
+
+    /// Returns the schema of this file scan task as a reference.
+    pub fn schema(&self) -> &Schema {
+        &self.schema
+    }
+
+    /// Returns the schema of this file scan task as a SchemaRef.
+    pub fn schema_ref(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::HashMap;
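
Note on the approach: the heart of this change is field-id-driven column resolution. Both the parquet file's Arrow schema and the snapshot's Arrow schema are indexed by the Iceberg field id carried in each field's metadata (read via parquet::arrow::PARQUET_FIELD_ID_META_KEY, i.e. the "PARQUET:field_id" key), and each projected field id is then resolved to a PassThrough, Promote, or Add action, so skipped and re-ordered fields fall out naturally. The following is a minimal, self-contained sketch of that mapping against the plain arrow crate; Action, field_with_id, and field_id_map are illustrative names invented here, not this patch's API, and Add is simplified to null-filling instead of the patch's initial_default handling.

use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::{new_null_array, ArrayRef, Int32Array};
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

// How each target column is produced, mirroring the patch's EvolutionAction.
enum Action {
    PassThrough { source_index: usize },
    Promote { source_index: usize, target_type: DataType },
    AddNull { target_type: DataType }, // stand-in for the patch's `Add { value, .. }`
}

// Attach an Iceberg field id to an Arrow field, the way parquet-rs records it.
fn field_with_id(name: &str, data_type: DataType, id: i32) -> Field {
    Field::new(name, data_type, true).with_metadata(HashMap::from([(
        "PARQUET:field_id".to_string(),
        id.to_string(),
    )]))
}

// field id -> column index, built once per schema
// (cf. build_field_id_to_arrow_schema_map in the patch).
fn field_id_map(schema: &Schema) -> HashMap<i32, usize> {
    schema
        .fields()
        .iter()
        .enumerate()
        .filter_map(|(idx, f)| {
            let id = f.metadata().get("PARQUET:field_id")?.parse().ok()?;
            Some((id, idx))
        })
        .collect()
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The file was written when the schema was {1: a int32, 2: b int32}.
    let file_schema = Schema::new(vec![
        field_with_id("a", DataType::Int32, 1),
        field_with_id("b", DataType::Int32, 2),
    ]);
    let batch = RecordBatch::try_new(
        Arc::new(file_schema.clone()),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
            Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef,
        ],
    )?;

    // The scan projects [2, 3, 1]: re-ordered, with b promoted to int64 and a
    // brand-new column c that the file predates.
    let target_schema = Arc::new(Schema::new(vec![
        field_with_id("b", DataType::Int64, 2),
        field_with_id("c", DataType::Utf8, 3),
        field_with_id("a", DataType::Int32, 1),
    ]));

    // Resolve each projected field against the file's columns by field id.
    let source_ids = field_id_map(&file_schema);
    let actions: Vec<Action> = target_schema
        .fields()
        .iter()
        .map(|tf| {
            let id: i32 = tf.metadata()["PARQUET:field_id"].parse().unwrap();
            match source_ids.get(&id) {
                Some(&idx) if file_schema.field(idx).data_type() == tf.data_type() => {
                    Action::PassThrough { source_index: idx }
                }
                Some(&idx) => Action::Promote {
                    source_index: idx,
                    target_type: tf.data_type().clone(),
                },
                None => Action::AddNull {
                    target_type: tf.data_type().clone(),
                },
            }
        })
        .collect();

    // Apply the actions to produce a batch shaped like the target schema.
    let columns = actions
        .iter()
        .map(|action| {
            Ok(match action {
                Action::PassThrough { source_index } => batch.column(*source_index).clone(),
                Action::Promote { source_index, target_type } => {
                    cast(batch.column(*source_index), target_type)?
                }
                Action::AddNull { target_type } => new_null_array(target_type, batch.num_rows()),
            })
        })
        .collect::<Result<Vec<ArrayRef>, ArrowError>>()?;

    let transformed = RecordBatch::try_new(target_schema, columns)?;
    assert_eq!(transformed.num_columns(), 3);
    Ok(())
}

Under this scheme a projection that skips a field id, as the updated test does with field 13, simply never generates an action for it, which is exactly the gap the previous index-based implementation could not handle.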
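The diff elides the body of create_column, which transform_columns relies on to materialize Add columns. As a rough, hypothetical sketch of what such a helper has to do (Lit and constant_column are stand-ins invented here, not the crate's PrimitiveLiteral machinery): repeat the field's initial-default for every row, or produce an all-null array when no default was recorded, which matches how Iceberg defines reads of columns added without an initial default.

use std::sync::Arc;

use arrow::array::{new_null_array, Array, ArrayRef, Int32Array, Int64Array, StringArray};
use arrow::datatypes::DataType;

// Hypothetical stand-in for the crate's PrimitiveLiteral.
enum Lit {
    Int(i64),
    Str(String),
}

// Build a column of `num_rows` identical values for a field added by schema
// evolution: repeat the recorded initial-default, or fill with nulls if the
// field has no default.
fn constant_column(target_type: &DataType, default: Option<&Lit>, num_rows: usize) -> ArrayRef {
    match (target_type, default) {
        (DataType::Int32, Some(Lit::Int(v))) => {
            Arc::new(Int32Array::from(vec![*v as i32; num_rows]))
        }
        (DataType::Int64, Some(Lit::Int(v))) => Arc::new(Int64Array::from(vec![*v; num_rows])),
        (DataType::Utf8, Some(Lit::Str(s))) => {
            Arc::new(StringArray::from(vec![s.as_str(); num_rows]))
        }
        // No default recorded: the column is all-null, like the test's column "a".
        (data_type, None) => new_null_array(data_type, num_rows),
        _ => unimplemented!("type/default combination not covered by this sketch"),
    }
}

fn main() {
    let col = constant_column(&DataType::Utf8, Some(&Lit::Str("(╯°□°)╯".into())), 3);
    assert_eq!(col.len(), 3);
    assert_eq!(col.null_count(), 0);
}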