From 0fbc490b9ce1d5678848701a661c33eff928b024 Mon Sep 17 00:00:00 2001
From: Martin Mauch
Date: Sun, 29 Sep 2019 00:52:01 +0200
Subject: [PATCH] Generate null values for nullable fields

---
 .../spark/testing/DataframeGenerator.scala   | 24 ++++---
 .../spark/testing/SampleScalaCheckTest.scala | 62 ++++++++++++++-----
 2 files changed, 61 insertions(+), 25 deletions(-)

diff --git a/core/src/main/1.3/scala/com/holdenkarau/spark/testing/DataframeGenerator.scala b/core/src/main/1.3/scala/com/holdenkarau/spark/testing/DataframeGenerator.scala
index 2b7f4910..40a6df58 100644
--- a/core/src/main/1.3/scala/com/holdenkarau/spark/testing/DataframeGenerator.scala
+++ b/core/src/main/1.3/scala/com/holdenkarau/spark/testing/DataframeGenerator.scala
@@ -95,20 +95,20 @@ object DataframeGenerator {
     List[Gen[Any]] = {
     val generatorMap = userGenerators.map(
       generator => (generator.columnName -> generator)).toMap
-    (0 until fields.length).toList.map(index => {
-      if (generatorMap.contains(fields(index).name)) {
-        generatorMap.get(fields(index).name) match {
+    fields.toList.map { field =>
+      if (generatorMap.contains(field.name)) {
+        generatorMap.get(field.name) match {
           case Some(gen: Column) => gen.gen
-          case Some(list: ColumnList) => getGenerator(fields(index).dataType, list.gen)
+          case Some(list: ColumnList) => getGenerator(field.dataType, list.gen, nullable = field.nullable)
         }
       }
-      else getGenerator(fields(index).dataType)
-    })
+      else getGenerator(field.dataType, nullable = field.nullable)
+    }
   }
 
   private def getGenerator(
-    dataType: DataType, generators: Seq[ColumnGenerator] = Seq()): Gen[Any] = {
-    dataType match {
+    dataType: DataType, generators: Seq[ColumnGenerator] = Seq(), nullable: Boolean = false): Gen[Any] = {
+    val nonNullGen = dataType match {
       case ByteType => Arbitrary.arbitrary[Byte]
       case ShortType => Arbitrary.arbitrary[Short]
       case IntegerType => Arbitrary.arbitrary[Int]
@@ -121,12 +121,12 @@ object DataframeGenerator {
       case TimestampType => Arbitrary.arbLong.arbitrary.map(new Timestamp(_))
       case DateType => Arbitrary.arbLong.arbitrary.map(new Date(_))
       case arr: ArrayType => {
-        val elementGenerator = getGenerator(arr.elementType)
+        val elementGenerator = getGenerator(arr.elementType, nullable = arr.containsNull)
         Gen.listOf(elementGenerator)
       }
       case map: MapType => {
         val keyGenerator = getGenerator(map.keyType)
-        val valueGenerator = getGenerator(map.valueType)
+        val valueGenerator = getGenerator(map.valueType, nullable = map.valueContainsNull)
         val keyValueGenerator: Gen[(Any, Any)] = for {
           key <- keyGenerator
           value <- valueGenerator
         } yield (key, value)
@@ -139,6 +139,10 @@ object DataframeGenerator {
       case _ => throw new UnsupportedOperationException(
         s"Type: $dataType not supported")
     }
+    if (nullable)
+      Gen.oneOf(nonNullGen, Gen.const(null))
+    else
+      nonNullGen
   }
 }
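The heart of the change is the final Gen.oneOf wrapping in getGenerator: the
generator for a nullable field now chooses uniformly between the ordinary
value generator and a constant null. In isolation the pattern behaves like the
self-contained ScalaCheck sketch below (illustrative only, not part of the
patch; the Int element type, sample count, and object name are arbitrary
choices for the example):

import org.scalacheck.{Arbitrary, Gen}

object NullWrappingSketch {
  def main(args: Array[String]): Unit = {
    // Same shape as the patched code: wrap a non-null generator so that
    // the null branch is chosen with 50% probability.
    val nonNullGen: Gen[Any] = Arbitrary.arbitrary[Int]
    val nullableGen: Gen[Any] = Gen.oneOf(nonNullGen, Gen.const(null))

    // Sampling yields a mix of nulls and ordinary values.
    val samples = List.fill(20)(nullableGen.sample).flatten
    println(samples.count(_ == null))  // roughly 10 out of 20
  }
}

Because the wrapping happens inside getGenerator, it applies recursively:
array elements honour containsNull and map values honour valueContainsNull, as
the two middle hunks above show.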
+ StructField("booleanType", BooleanType) :: + StructField("timestampType", TimestampType) :: + StructField("dateType", DateType) :: + StructField("arrayType", ArrayType(TimestampType)) :: + StructField("mapType", + MapType(LongType, TimestampType, valueContainsNull = true)) :: + StructField("structType", + StructType(StructField("timestampType", TimestampType) :: Nil)) :: Nil test("second dataframe's evaluation has the same values as first") { implicit val generatorDrivenConfig = PropertyCheckConfig(minSize = 1, maxSize = 1) - val fields = StructField("byteType", ByteType) :: - StructField("shortType", ShortType) :: - StructField("intType", IntegerType) :: - StructField("longType", LongType) :: - StructField("doubleType", DoubleType) :: - StructField("stringType", StringType) :: - StructField("binaryType", BinaryType) :: - StructField("booleanType", BooleanType) :: - StructField("timestampType", TimestampType) :: - StructField("dateType", DateType) :: - StructField("arrayType", ArrayType(TimestampType)) :: - StructField("mapType", - MapType(LongType, TimestampType, valueContainsNull = true)) :: - StructField("structType", - StructType(StructField("timestampType", TimestampType) :: Nil)) :: Nil val sqlContext = new SQLContext(sc) val dataframeGen = @@ -352,6 +352,38 @@ class SampleScalaCheckTest extends FunSuite check(property) } + test("nullable fields contain null values as well") { + implicit val generatorDrivenConfig = + PropertyCheckConfig(minSize = 1, maxSize = 1) + val nullableFields = fields.map(f => f.copy(nullable = true, name = s"${f.name}Nullable")) + val sqlContext = new SQLContext(sc) + val allFields = fields ::: nullableFields + val dataframeGen = + DataframeGenerator.arbitraryDataFrame(sqlContext, StructType(allFields)) + + val property = + forAll(Gen.resize(100, dataframeGen.arbitrary)) { + dataframe => { + allFields.forall { f => + val colValues = dataframe.select(f.name).collect().map(_.get(0)) + if (f.nullable) + colValues.contains(null) || + // Unfortunately, dataframeGen.arbitrary sometimes generates DataFrames where all + // rows have exactly identical values. + // In that case, even generating many rows doesn't help to get some nulls... + // To work around this we check if we generated at least some distinct values. + colValues.distinct.size < 4 || + // This is needed for Array-valued fields where .distinct returns all values, even when + // they're identical. + colValues.size == colValues.distinct.size + else + !colValues.contains(null) + } + } + } + + check(property) + } private def filterOne(rdd: RDD[String]): RDD[Int] = { rdd.filter(_.length > 2).map(_.length)