Bulk Load CDK: Add Avro/Parquet Mapper Pipelines to Writers
johnny-schmidt committed Nov 7, 2024
1 parent 497be1c commit 99e4115
Showing 24 changed files with 319 additions and 227 deletions.
@@ -80,7 +80,7 @@ class MockStreamLoader(override val stream: DestinationStream) : StreamLoader {
stream.generationId,
it.data as ObjectValue,
OutputRecord.Meta(
changes = it.meta?.changes ?: mutableListOf(),
changes = it.meta?.changes ?: listOf(),
syncId = stream.syncId
),
)
@@ -8,6 +8,10 @@ interface AirbyteSchemaMapper {
fun map(schema: AirbyteType): AirbyteType
}

class AirbyteSchemaNoopMapper : AirbyteSchemaMapper {
override fun map(schema: AirbyteType): AirbyteType = schema
}

interface AirbyteSchemaIdentityMapper : AirbyteSchemaMapper {
override fun map(schema: AirbyteType): AirbyteType =
when (schema) {
@@ -10,11 +10,11 @@ import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason

interface AirbyteValueMapper {
val collectedChanges: List<DestinationRecord.Change>

fun map(
value: AirbyteValue,
schema: AirbyteType,
path: List<String> = emptyList(),
nullable: Boolean = false
): AirbyteValue
}

@@ -26,6 +26,7 @@ class AirbyteValueNoopMapper : AirbyteValueMapper {
value: AirbyteValue,
schema: AirbyteType,
path: List<String>,
nullable: Boolean
): AirbyteValue = value
}

@@ -49,43 +50,58 @@ open class AirbyteValueIdentityMapper : AirbyteValueMapper {
value: AirbyteValue,
schema: AirbyteType,
path: List<String>,
nullable: Boolean,
): AirbyteValue =
try {
when (schema) {
is ObjectType -> mapObject(value as ObjectValue, schema, path)
is ObjectTypeWithoutSchema ->
mapObjectWithoutSchema(value as ObjectValue, schema, path)
is ObjectTypeWithEmptySchema ->
mapObjectWithEmptySchema(value as ObjectValue, schema, path)
is ArrayType -> mapArray(value as ArrayValue, schema, path)
is ArrayTypeWithoutSchema ->
mapArrayWithoutSchema(value as ArrayValue, schema, path)
is UnionType -> mapUnion(value, schema, path)
is BooleanType -> mapBoolean(value as BooleanValue, path)
is NumberType -> mapNumber(value as NumberValue, path)
is StringType -> mapString(value as StringValue, path)
is IntegerType -> mapInteger(value as IntegerValue, path)
is DateType -> mapDate(value as DateValue, path)
is TimeTypeWithTimezone -> mapTimeWithTimezone(value as TimeValue, path)
is TimeTypeWithoutTimezone -> mapTimeWithoutTimezone(value as TimeValue, path)
is TimestampTypeWithTimezone ->
mapTimestampWithTimezone(value as TimestampValue, path)
is TimestampTypeWithoutTimezone ->
mapTimestampWithoutTimezone(value as TimestampValue, path)
is UnknownType -> {
collectFailure(path)
mapNull(path)
}
if (value is NullValue) {
if (!nullable) {
throw IllegalStateException(
"null value for non-nullable field at path: ${path.joinToString(".")}"
)
}
} catch (e: Exception) {
collectFailure(path)
mapNull(path)
}
} else
try {
when (schema) {
is ObjectType -> mapObject(value as ObjectValue, schema, path)
is ObjectTypeWithoutSchema ->
mapObjectWithoutSchema(value as ObjectValue, schema, path)
is ObjectTypeWithEmptySchema ->
mapObjectWithEmptySchema(value as ObjectValue, schema, path)
is ArrayType -> mapArray(value as ArrayValue, schema, path)
is ArrayTypeWithoutSchema ->
mapArrayWithoutSchema(value as ArrayValue, schema, path)
is UnionType -> mapUnion(value, schema, path)
is BooleanType -> mapBoolean(value as BooleanValue, path)
is NumberType -> mapNumber(value as NumberValue, path)
is StringType -> mapString(value as StringValue, path)
is IntegerType -> mapInteger(value as IntegerValue, path)
is DateType -> mapDate(value as DateValue, path)
is TimeTypeWithTimezone ->
mapTimeWithTimezone(
value as TimeValue,
path,
)
is TimeTypeWithoutTimezone ->
mapTimeWithoutTimezone(
value as TimeValue,
path,
)
is TimestampTypeWithTimezone ->
mapTimestampWithTimezone(value as TimestampValue, path)
is TimestampTypeWithoutTimezone ->
mapTimestampWithoutTimezone(value as TimestampValue, path)
is UnknownType -> mapUnknown(value as UnknownValue, path)
}
} catch (e: Exception) {
collectFailure(path)
map(NullValue, schema, path, nullable)
}

open fun mapObject(value: ObjectValue, schema: ObjectType, path: List<String>): AirbyteValue {
val values = LinkedHashMap<String, AirbyteValue>()
schema.properties.forEach { (name, field) ->
values[name] = map(value.values[name] ?: NullValue, field.type, path + name)
values[name] =
map(value.values[name] ?: NullValue, field.type, path + name, field.nullable)
}
return ObjectValue(values)
}
@@ -105,7 +121,7 @@ open class AirbyteValueIdentityMapper : AirbyteValueMapper {
open fun mapArray(value: ArrayValue, schema: ArrayType, path: List<String>): AirbyteValue {
return ArrayValue(
value.values.mapIndexed { index, element ->
map(element, schema.items.type, path + "[$index]")
map(element, schema.items.type, path + "[$index]", schema.items.nullable)
}
)
}
@@ -140,4 +156,6 @@ open class AirbyteValueIdentityMapper : AirbyteValueMapper {
value

open fun mapNull(path: List<String>): AirbyteValue = NullValue

open fun mapUnknown(value: UnknownValue, path: List<String>): AirbyteValue = value
}
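A hedged sketch (not part of the diff) of what the new `nullable` parameter means for callers of `AirbyteValueIdentityMapper`; the field paths and types here are invented for illustration.

val mapper = AirbyteValueIdentityMapper()

// A null value for a field the schema marks nullable is accepted.
mapper.map(NullValue, StringType, path = listOf("user", "nickname"), nullable = true)

// A null value for a non-nullable field now fails fast instead of passing through:
// "null value for non-nullable field at path: user.id"
try {
    mapper.map(NullValue, StringType, path = listOf("user", "id"), nullable = false)
} catch (e: IllegalStateException) {
    println(e.message)
}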
@@ -10,7 +10,7 @@ import io.airbyte.cdk.load.message.DestinationRecord.Meta
import java.util.*

class DestinationRecordToAirbyteValueWithMeta(val stream: DestinationStream) {
fun convert(data: AirbyteValue, emittedAtMs: Long, meta: DestinationRecord.Meta?): ObjectValue {
fun convert(data: AirbyteValue, emittedAtMs: Long, meta: Meta?): ObjectValue {
return ObjectValue(
linkedMapOf(
Meta.COLUMN_NAME_AB_RAW_ID to StringValue(UUID.randomUUID().toString()),
@@ -41,5 +41,10 @@ class DestinationRecordToAirbyteValueWithMeta(val stream: DestinationStream) {
}
}

fun Pair<AirbyteValue, List<DestinationRecord.Change>>.withAirbyteMeta(
stream: DestinationStream,
emittedAtMs: Long
) = DestinationRecordToAirbyteValueWithMeta(stream).convert(first, emittedAtMs, Meta(second))

fun DestinationRecord.dataWithAirbyteMeta(stream: DestinationStream) =
DestinationRecordToAirbyteValueWithMeta(stream).convert(data, emittedAtMs, meta)
@@ -23,13 +23,16 @@ class MapperPipeline(
finalSchema = schemas.last()
}

fun map(data: AirbyteValue): Pair<AirbyteValue, List<Change>> {
fun map(data: AirbyteValue, changes: List<Change>? = null): Pair<AirbyteValue, List<Change>> {
val results =
schemasWithMappers.runningFold(data) { value, (schema, mapper) ->
mapper.map(value, schema)
}
val changesFlattened =
schemasWithMappers.flatMap { it.second.collectedChanges }.toSet().toList()
schemasWithMappers
.flatMap { it.second.collectedChanges + (changes ?: emptyList()) }
.toSet()
.toList()
return results.last() to changesFlattened
}
}
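A hedged sketch of how a writer might feed a record through a configured pipeline with the new `changes` parameter and then attach the Airbyte meta columns. `pipeline`, `stream`, and `record` stand in for a concrete `MapperPipeline`, `DestinationStream`, and `DestinationRecord`; they are assumptions for illustration, not code from this commit.

// Run the record through the mapper chain, carrying along any changes the record
// already arrived with so they end up in the flattened change list.
val (mappedValue, changes) = pipeline.map(record.data, record.meta?.changes)

// Wrap the result with _airbyte_raw_id, _airbyte_extracted_at, etc. for the writer.
val row = (mappedValue to changes).withAirbyteMeta(stream, record.emittedAtMs)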
@@ -44,11 +44,10 @@ class MergeUnions : AirbyteSchemaIdentityMapper {
continue
}

if (existingField != field) {
throw IllegalArgumentException(
"Cannot merge unions of objects with different types for the same field"
)
}
// Combine the fields, recursively merging unions, object fields, etc
val mergedFields = mapUnion(UnionType(listOf(existingField.type, field.type)))
newProperties[name] =
FieldType(mergedFields, existingField.nullable || field.nullable)

// If the fields are identical, we can just keep the existing field
}
@@ -12,6 +12,7 @@ class SchemalessTypesToJson : AirbyteSchemaIdentityMapper {
override fun mapObjectWithEmptySchema(schema: ObjectTypeWithEmptySchema): AirbyteType =
StringType
override fun mapArrayWithoutSchema(schema: ArrayTypeWithoutSchema): AirbyteType = StringType
override fun mapUnknown(schema: UnknownType): AirbyteType = StringType
}

class SchemalessValuesToJson : AirbyteValueIdentityMapper() {
@@ -30,4 +31,6 @@ class SchemalessValuesToJson : AirbyteValueIdentityMapper() {
schema: ArrayTypeWithoutSchema,
path: List<String>
): AirbyteValue = value.toJson().serializeToString().let(::StringValue)
override fun mapUnknown(value: UnknownValue, path: List<String>): AirbyteValue =
StringValue(value.what)
}
@@ -13,18 +13,6 @@ import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoUnit

class TimeStringTypeToIntegerType : AirbyteSchemaIdentityMapper {
override fun mapDate(schema: DateType): AirbyteType = IntegerType
override fun mapTimeTypeWithTimezone(schema: TimeTypeWithTimezone): AirbyteType = IntegerType
override fun mapTimeTypeWithoutTimezone(schema: TimeTypeWithoutTimezone): AirbyteType =
IntegerType
override fun mapTimestampTypeWithTimezone(schema: TimestampTypeWithTimezone): AirbyteType =
IntegerType
override fun mapTimestampTypeWithoutTimezone(
schema: TimestampTypeWithoutTimezone
): AirbyteType = IntegerType
}

/**
* NOTE: To keep parity with the old avro/parquet code, we will always first try to parse the value
* as with timezone, then fall back to without. But in theory we should be more strict.
@@ -62,7 +62,7 @@ data class DestinationRecord(
serialized = "",
)

data class Meta(val changes: MutableList<Change> = mutableListOf()) {
data class Meta(val changes: List<Change> = mutableListOf()) {
companion object {
const val COLUMN_NAME_AB_RAW_ID: String = "_airbyte_raw_id"
const val COLUMN_NAME_AB_EXTRACTED_AT: String = "_airbyte_extracted_at"
@@ -52,7 +52,8 @@ class AirbyteValueIdentityMapperTest {
.with(
TimestampValue("2021-01-01T12:00:00Z"),
TimeTypeWithTimezone,
nameOverride = "bad"
nameOverride = "bad",
nullable = true
)
.build()
val mapper = AirbyteValueIdentityMapper()
@@ -114,4 +114,16 @@ class MapperPipelineTest {

Assertions.assertEquals(2, changes.size, "two failures were captured")
}

@Test
fun testFailedMappingThrowsOnNonNullable() {
val (inputValue, inputSchema, _) =
ValueTestBuilder<Root>()
.with(IntegerValue(2), IntegerType, NullValue, nullable = false) // fail: reject 2
.build()

val pipeline = makePipeline(inputSchema)

Assertions.assertThrows(IllegalStateException::class.java) { pipeline.map(inputValue) }
}
}
@@ -8,7 +8,6 @@ import io.airbyte.cdk.load.test.util.Root
import io.airbyte.cdk.load.test.util.SchemaRecordBuilder
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows

class MergeUnionsTest {
@Test
@@ -41,10 +40,25 @@ }
}

@Test
fun testNameClashFails() {
val (inputSchema, _) =
fun testNameClash() {
val (inputSchema, expectedOutput) =
SchemaRecordBuilder<Root>()
.withUnion()
.withUnion(
expectedInstead =
FieldType(
ObjectType(
properties =
linkedMapOf(
"foo" to
FieldType(
UnionType(listOf(StringType, IntegerType)),
false
)
)
),
false
)
)
.withRecord()
.with(StringType, nameOverride = "foo")
.endRecord()
Expand All @@ -53,7 +67,8 @@ class MergeUnionsTest {
.endRecord()
.endUnion()
.build()
assertThrows<IllegalArgumentException> { MergeUnions().map(inputSchema) }
val output = MergeUnions().map(inputSchema)
Assertions.assertEquals(expectedOutput, output)
}

@Test
@@ -4,8 +4,6 @@

package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.test.util.Root
import io.airbyte.cdk.load.test.util.SchemaRecordBuilder
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test

@@ -120,31 +118,4 @@ class TimeStringToIntegerTest {
)
}
}

@Test
fun testBasicSchemaBehavior() {
val (inputSchema, expectedOutput) =
SchemaRecordBuilder<Root>()
.with(DateType, IntegerType)
.withRecord()
.with(TimestampTypeWithTimezone, IntegerType)
.endRecord()
.with(TimestampTypeWithoutTimezone, IntegerType)
.withRecord()
.with(TimeTypeWithTimezone, IntegerType)
.withRecord()
.with(TimeTypeWithoutTimezone, IntegerType)
.endRecord()
.endRecord()
.withUnion(
expectedInstead =
FieldType(UnionType(listOf(IntegerType, IntegerType)), nullable = false)
)
.with(DateType)
.with(TimeTypeWithTimezone)
.endUnion()
.build()
val output = TimeStringTypeToIntegerType().map(inputSchema)
Assertions.assertEquals(expectedOutput, output)
}
}
@@ -29,7 +29,7 @@ data class OutputRecord(
* that we write to the destination.
*/
data class Meta(
val changes: MutableList<Change> = mutableListOf(),
val changes: List<Change> = listOf(),
val syncId: Long? = null,
)
