diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndex.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndex.scala
index 113ad8bf82..0820e505ed 100644
--- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndex.scala
+++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndex.scala
@@ -12,6 +12,7 @@ import org.apache.spark.sql.{Column, DataFrame}
 import org.apache.spark.sql.catalyst.expressions.{Expression, Predicate}
 import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
 import org.apache.spark.sql.flint.config.FlintSparkConf
+import org.apache.spark.sql.functions.isnull
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -32,14 +33,19 @@ case class FlintSparkSkippingFileIndex(
       partitionFilters: Seq[Expression],
       dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {
 
-    // TODO: try to avoid the list call if no hybrid scan
+    // TODO: make this listFiles call only in hybrid scan mode
     val partitions = baseFileIndex.listFiles(partitionFilters, dataFilters)
+    val selectedFiles =
+      if (FlintSparkConf().isHybridScanEnabled) {
+        selectFilesFromIndexAndSource(partitions)
+      } else {
+        selectFilesFromIndexOnly()
+      }
 
-    if (FlintSparkConf().isHybridScanEnabled) {
-      scanFilesFromIndexAndSource(partitions)
-    } else {
-      scanFilesFromIndex(partitions)
-    }
+    // Keep only the partition files present in the selected file list above
+    partitions
+      .map(p => p.copy(files = p.files.filter(f => isFileNotSkipped(selectedFiles, f))))
+      .filter(p => p.files.nonEmpty)
   }
 
   override def rootPaths: Seq[Path] = baseFileIndex.rootPaths
@@ -52,23 +58,42 @@ case class FlintSparkSkippingFileIndex(
 
   override def partitionSchema: StructType = baseFileIndex.partitionSchema
 
-  private def scanFilesFromIndexAndSource(
-      partitions: Seq[PartitionDirectory]): Seq[PartitionDirectory] = {
-    Seq.empty
-  }
-
-  private def scanFilesFromIndex(partitions: Seq[PartitionDirectory]): Seq[PartitionDirectory] = {
-    val selectedFiles =
-      indexScan
-        .filter(new Column(indexFilter))
-        .select(FILE_PATH_COLUMN)
-        .collect
-        .map(_.getString(0))
-        .toSet
+  /*
+   * Left join source partitions and index data to keep unrefreshed source files.
+   * Expressed in SQL:
+   *   SELECT left.file_path
+   *   FROM partitions AS left
+   *   LEFT OUTER JOIN indexScan AS right
+   *     ON left.file_path = right.file_path
+   *   WHERE right.file_path IS NULL
+   *     OR [indexFilter]
+   */
+  private def selectFilesFromIndexAndSource(partitions: Seq[PartitionDirectory]): Set[String] = {
+    val sparkSession = indexScan.sparkSession
+    import sparkSession.implicits._
 
     partitions
-      .map(p => p.copy(files = p.files.filter(f => isFileNotSkipped(selectedFiles, f))))
-      .filter(p => p.files.nonEmpty)
+      .flatMap(_.files.map(f => f.getPath.toString))
+      .toDF(FILE_PATH_COLUMN)
+      .join(indexScan, Seq(FILE_PATH_COLUMN), "left")
+      .filter(isnull(indexScan(FILE_PATH_COLUMN)) || new Column(indexFilter))
+      .select(FILE_PATH_COLUMN)
+      .collect()
+      .map(_.getString(0))
+      .toSet
+  }
+
+  /*
+   * Consider file paths in the index data alone. In this case, the index filter can be pushed
+   * down to the index store.
+   */
+  private def selectFilesFromIndexOnly(): Set[String] = {
+    indexScan
+      .filter(new Column(indexFilter))
+      .select(FILE_PATH_COLUMN)
+      .collect
+      .map(_.getString(0))
+      .toSet
   }
 
   private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatus) = {
diff --git a/flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndexSuite.scala b/flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndexSuite.scala
index 8e7b1c2a32..df661cdbcf 100644
--- a/flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndexSuite.scala
+++ b/flint/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndexSuite.scala
@@ -17,21 +17,82 @@ import org.apache.spark.sql.{Column, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Literal, Predicate}
 import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
+import org.apache.spark.sql.flint.config.FlintSparkConf.HYBRID_SCAN_ENABLED
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types._
 
 class FlintSparkSkippingFileIndexSuite extends FlintSuite with Matchers {
 
-  test("should skip unknown source files in non-hybrid-scan mode") {
+  /** Test source partition data. */
+  private val partition1 = "partition-1" -> Seq("file-1", "file-2")
+  private val partition2 = "partition-2" -> Seq("file-3")
+
+  /** Test index data schema. */
+  private val schema = Map((FILE_PATH_COLUMN, StringType), ("year", IntegerType))
+
+  test("should keep files returned from index") {
+    assertFlintFileIndex()
+      .withSourceFiles(Map(partition1))
+      .withIndexData(schema, Seq(Row("file-1", 2023), Row("file-2", 2022)))
+      .withIndexFilter(col("year") === 2023)
+      .shouldScanSourceFiles(Map("partition-1" -> Seq("file-1")))
+  }
+
+  test("should keep files of multiple partitions returned from index") {
+    assertFlintFileIndex()
+      .withSourceFiles(Map(partition1, partition2))
+      .withIndexData(schema, Seq(Row("file-1", 2023), Row("file-2", 2022), Row("file-3", 2023)))
+      .withIndexFilter(col("year") === 2023)
+      .shouldScanSourceFiles(Map("partition-1" -> Seq("file-1"), "partition-2" -> Seq("file-3")))
+  }
+
+  test("should skip unrefreshed source files by default") {
     assertFlintFileIndex()
-      .withSourceFiles(Map("partition-1" -> Seq("file-1", "file-2")))
+      .withSourceFiles(Map(partition1))
       .withIndexData(
-        Map((FILE_PATH_COLUMN, StringType), ("year", IntegerType)),
-        Seq(Row("file-1", 2023), Row("file-2", 2022)))
+        schema,
+        Seq(Row("file-1", 2023)) // file-2 is not refreshed to index yet
+      )
       .withIndexFilter(col("year") === 2023)
       .shouldScanSourceFiles(Map("partition-1" -> Seq("file-1")))
   }
 
+  test("should not skip unrefreshed source files in hybrid-scan mode") {
+    withHybridScanEnabled {
+      assertFlintFileIndex()
+        .withSourceFiles(Map(partition1))
+        .withIndexData(
+          schema,
+          Seq(Row("file-1", 2023)) // file-2 is not refreshed to index yet
+        )
+        .withIndexFilter(col("year") === 2023)
+        .shouldScanSourceFiles(Map("partition-1" -> Seq("file-1", "file-2")))
+    }
+  }
+
+  test("should not skip unrefreshed source files of multiple partitions in hybrid-scan mode") {
+    withHybridScanEnabled {
+      assertFlintFileIndex()
+        .withSourceFiles(Map(partition1, partition2))
+        .withIndexData(
+          schema,
+          Seq(Row("file-1", 2023)) // file-2 is not refreshed to index yet
+        )
+        .withIndexFilter(col("year") === 2023)
+        .shouldScanSourceFiles(
+          Map("partition-1" -> Seq("file-1", "file-2"), "partition-2" -> Seq("file-3")))
+    }
+  }
+
+  private def withHybridScanEnabled(block: => Unit): Unit = {
+    setFlintSparkConf(HYBRID_SCAN_ENABLED, "true")
+    try {
+      block
+    } finally {
+      setFlintSparkConf(HYBRID_SCAN_ENABLED, "false")
+    }
+  }
+
   private def assertFlintFileIndex(): AssertionHelper = {
     new AssertionHelper
   }
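
A minimal standalone sketch (not part of the diff above) of the hybrid-scan selection that selectFilesFromIndexAndSource performs: source file paths are left-outer-joined with the index data, so files that have no index entry yet (NULL on the right side) are kept alongside the files matched by the index filter. The object name, column values, and local Spark session here are illustrative assumptions only; the join-then-filter pattern mirrors the one in the diff.

// Hypothetical driver object, only to make the sketch self-contained and runnable.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, isnull}

object HybridScanSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("hybrid-scan-sketch").getOrCreate()
  import spark.implicits._

  // Source files currently listed from the table; "file-3" has no index entry yet.
  val sourceFiles = Seq("file-1", "file-2", "file-3").toDF("file_path")

  // Skipping index data: one row per indexed source file plus its skipping metadata.
  val indexScan = Seq(("file-1", 2023), ("file-2", 2022)).toDF("file_path", "year")

  // Keep files that are either unindexed (right side is NULL) or selected by the index filter.
  val selected = sourceFiles
    .join(indexScan, Seq("file_path"), "left")
    .filter(isnull(indexScan("file_path")) || col("year") === 2023)
    .select("file_path")
    .collect()
    .map(_.getString(0))
    .toSet

  println(selected) // expected: Set(file-1, file-3)
  spark.stop()
}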