Skip to content

Commit

Permalink
Generate null values for nullable fields
Browse files Browse the repository at this point in the history
  • Loading branch information
nightscape committed Sep 28, 2019
1 parent 4f130b0 commit 0fbc490
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,20 +95,20 @@ object DataframeGenerator {
List[Gen[Any]] = {
val generatorMap = userGenerators.map(
generator => (generator.columnName -> generator)).toMap
(0 until fields.length).toList.map(index => {
if (generatorMap.contains(fields(index).name)) {
generatorMap.get(fields(index).name) match {
fields.toList.map { field =>
if (generatorMap.contains(field.name)) {
generatorMap.get(field.name) match {
case Some(gen: Column) => gen.gen
case Some(list: ColumnList) => getGenerator(fields(index).dataType, list.gen)
case Some(list: ColumnList) => getGenerator(field.dataType, list.gen, nullable = field.nullable)
}
}
else getGenerator(fields(index).dataType)
})
else getGenerator(field.dataType, nullable = field.nullable)
}
}

private def getGenerator(
dataType: DataType, generators: Seq[ColumnGenerator] = Seq()): Gen[Any] = {
dataType match {
dataType: DataType, generators: Seq[ColumnGenerator] = Seq(), nullable: Boolean = false): Gen[Any] = {
val nonNullGen = dataType match {
case ByteType => Arbitrary.arbitrary[Byte]
case ShortType => Arbitrary.arbitrary[Short]
case IntegerType => Arbitrary.arbitrary[Int]
Expand All @@ -121,12 +121,12 @@ object DataframeGenerator {
case TimestampType => Arbitrary.arbLong.arbitrary.map(new Timestamp(_))
case DateType => Arbitrary.arbLong.arbitrary.map(new Date(_))
case arr: ArrayType => {
val elementGenerator = getGenerator(arr.elementType)
val elementGenerator = getGenerator(arr.elementType, nullable = arr.containsNull)
Gen.listOf(elementGenerator)
}
case map: MapType => {
val keyGenerator = getGenerator(map.keyType)
val valueGenerator = getGenerator(map.valueType)
val valueGenerator = getGenerator(map.valueType, nullable = map.valueContainsNull)
val keyValueGenerator: Gen[(Any, Any)] = for {
key <- keyGenerator
value <- valueGenerator
Expand All @@ -139,6 +139,10 @@ object DataframeGenerator {
case _ => throw new UnsupportedOperationException(
s"Type: $dataType not supported")
}
if (nullable)
Gen.oneOf(nonNullGen, Gen.const(null))
else
nonNullGen
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,24 +316,24 @@ class SampleScalaCheckTest extends FunSuite
check(property)
}

val fields = StructField("byteType", ByteType) ::
StructField("shortType", ShortType) ::
StructField("intType", IntegerType) ::
StructField("longType", LongType) ::
StructField("doubleType", DoubleType) ::
StructField("stringType", StringType) ::
StructField("binaryType", BinaryType) ::
StructField("booleanType", BooleanType) ::
StructField("timestampType", TimestampType) ::
StructField("dateType", DateType) ::
StructField("arrayType", ArrayType(TimestampType)) ::
StructField("mapType",
MapType(LongType, TimestampType, valueContainsNull = true)) ::
StructField("structType",
StructType(StructField("timestampType", TimestampType) :: Nil)) :: Nil
test("second dataframe's evaluation has the same values as first") {
implicit val generatorDrivenConfig =
PropertyCheckConfig(minSize = 1, maxSize = 1)
val fields = StructField("byteType", ByteType) ::
StructField("shortType", ShortType) ::
StructField("intType", IntegerType) ::
StructField("longType", LongType) ::
StructField("doubleType", DoubleType) ::
StructField("stringType", StringType) ::
StructField("binaryType", BinaryType) ::
StructField("booleanType", BooleanType) ::
StructField("timestampType", TimestampType) ::
StructField("dateType", DateType) ::
StructField("arrayType", ArrayType(TimestampType)) ::
StructField("mapType",
MapType(LongType, TimestampType, valueContainsNull = true)) ::
StructField("structType",
StructType(StructField("timestampType", TimestampType) :: Nil)) :: Nil

val sqlContext = new SQLContext(sc)
val dataframeGen =
Expand All @@ -352,6 +352,38 @@ class SampleScalaCheckTest extends FunSuite

check(property)
}
test("nullable fields contain null values as well") {
implicit val generatorDrivenConfig =
PropertyCheckConfig(minSize = 1, maxSize = 1)
val nullableFields = fields.map(f => f.copy(nullable = true, name = s"${f.name}Nullable"))
val sqlContext = new SQLContext(sc)
val allFields = fields ::: nullableFields
val dataframeGen =
DataframeGenerator.arbitraryDataFrame(sqlContext, StructType(allFields))

val property =
forAll(Gen.resize(100, dataframeGen.arbitrary)) {
dataframe => {
allFields.forall { f =>
val colValues = dataframe.select(f.name).collect().map(_.get(0))
if (f.nullable)
colValues.contains(null) ||
// Unfortunately, dataframeGen.arbitrary sometimes generates DataFrames where all
// rows have exactly identical values.
// In that case, even generating many rows doesn't help to get some nulls...
// To work around this we check if we generated at least some distinct values.
colValues.distinct.size < 4 ||
// This is needed for Array-valued fields where .distinct returns all values, even when
// they're identical.
colValues.size == colValues.distinct.size
else
!colValues.contains(null)
}
}
}

check(property)
}

private def filterOne(rdd: RDD[String]): RDD[Int] = {
rdd.filter(_.length > 2).map(_.length)
Expand Down

0 comments on commit 0fbc490

Please sign in to comment.