From e09ff1b7ca0c9d33c91d85a1f8c890e8534fb8ab Mon Sep 17 00:00:00 2001 From: Nuvindu Date: Thu, 9 May 2024 13:23:14 +0530 Subject: [PATCH] Remove unnecessary type tag declarations --- .../java/io/ballerina/lib/avro/Utils.java | 36 ++++++------ .../visitor/DeserializeRecordVisitor.java | 10 +--- .../visitor/DeserializeVisitor.java | 32 ++++------- .../serialize/visitor/SerializeVisitor.java | 57 +++++++++---------- 4 files changed, 54 insertions(+), 81 deletions(-) diff --git a/native/src/main/java/io/ballerina/lib/avro/Utils.java b/native/src/main/java/io/ballerina/lib/avro/Utils.java index 3a462af..fbd4039 100644 --- a/native/src/main/java/io/ballerina/lib/avro/Utils.java +++ b/native/src/main/java/io/ballerina/lib/avro/Utils.java @@ -38,35 +38,31 @@ private Utils() { public static final String ERROR_TYPE = "Error"; public static final String SERIALIZATION_ERROR = "Avro serialization error"; public static final String DESERIALIZATION_ERROR = "Avro deserialization error"; - public static final int STRING_TYPE = 5; - public static final int ARRAY_TYPE = 32; - public static final int MAP_TYPE = 27; - public static final int INTERSECTION_TYPE = 34; - public static final int REFERENCE_TYPE = 53; - public static final int RECORD_TYPE = 24; - public static final int INTEGER_TYPE = 1; - public static final int FLOAT_TYPE = 3; - public static final int BOOLEAN_TYPE = 6; public static BError createError(String message, Throwable throwable) { BError cause = ErrorCreator.createError(throwable); return ErrorCreator.createError(getModule(), ERROR_TYPE, StringUtils.fromString(message), cause, null); } - public static Type getMutableType(IntersectionType intersectionType) { - for (Type type : intersectionType.getConstituentTypes()) { - Type referredType = TypeUtils.getImpliedType(type); - if (referredType.getTag() == TypeTags.UNION_TAG) { - for (Type elementType : ((UnionType) referredType).getMemberTypes()) { - if (elementType.getTag() == TypeTags.MAP_TAG) { - return elementType; + public static Type getMutableType(Type dataType) { + if (dataType.getTag() == TypeTags.INTERSECTION_TAG) { + IntersectionType intersectionType = (IntersectionType) dataType; + for (Type type : intersectionType.getConstituentTypes()) { + Type referredType = TypeUtils.getImpliedType(type); + if (referredType.getTag() == TypeTags.UNION_TAG) { + for (Type elementType : ((UnionType) referredType).getMemberTypes()) { + if (elementType.getTag() == TypeTags.MAP_TAG) { + return elementType; + } } } + if (TypeUtils.getImpliedType(intersectionType.getEffectiveType()).getTag() == referredType.getTag()) { + return referredType; + } } - if (TypeUtils.getImpliedType(intersectionType.getEffectiveType()).getTag() == referredType.getTag()) { - return referredType; - } + } else { + return dataType; } - throw new IllegalStateException("Unsupported intersection type found: " + intersectionType); + throw new IllegalStateException("Unsupported intersection type found."); } } diff --git a/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/DeserializeRecordVisitor.java b/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/DeserializeRecordVisitor.java index 8ecb1da..a8bb9f6 100644 --- a/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/DeserializeRecordVisitor.java +++ b/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/DeserializeRecordVisitor.java @@ -5,7 +5,6 @@ import io.ballerina.lib.avro.deserialize.RecordDeserializer; import io.ballerina.lib.avro.deserialize.StringDeserializer; import io.ballerina.runtime.api.creators.ValueCreator; -import io.ballerina.runtime.api.types.IntersectionType; import io.ballerina.runtime.api.types.RecordType; import io.ballerina.runtime.api.types.Type; import io.ballerina.runtime.api.utils.StringUtils; @@ -26,7 +25,7 @@ public BMap visit(RecordDeserializer recordDeserializer, Generi Type originalType = recordDeserializer.getType(); Type type = recordDeserializer.getType(); Schema schema = recordDeserializer.getSchema(); - BMap avroRecord = createAvroRecord(type); + BMap avroRecord = ValueCreator.createRecordValue((RecordType) getMutableType(type)); for (Schema.Field field : schema.getFields()) { Object fieldData = rec.get(field.name()); switch (field.schema().getType()) { @@ -57,13 +56,6 @@ public BMap visit(RecordDeserializer recordDeserializer, Generi return avroRecord; } - private BMap createAvroRecord(Type type) { - if (type instanceof IntersectionType) { - type = getMutableType((IntersectionType) type); - } - return ValueCreator.createRecordValue((RecordType) type); - } - private void processMapField(BMap avroRecord, Schema.Field field, Object fieldData) throws Exception { Type mapType = extractMapType(avroRecord.getType()); diff --git a/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/DeserializeVisitor.java b/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/DeserializeVisitor.java index 3460928..a7ee89b 100644 --- a/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/DeserializeVisitor.java +++ b/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/DeserializeVisitor.java @@ -51,13 +51,6 @@ import java.util.List; import java.util.Map; -import static io.ballerina.lib.avro.Utils.ARRAY_TYPE; -import static io.ballerina.lib.avro.Utils.BOOLEAN_TYPE; -import static io.ballerina.lib.avro.Utils.FLOAT_TYPE; -import static io.ballerina.lib.avro.Utils.INTEGER_TYPE; -import static io.ballerina.lib.avro.Utils.RECORD_TYPE; -import static io.ballerina.lib.avro.Utils.REFERENCE_TYPE; -import static io.ballerina.lib.avro.Utils.STRING_TYPE; import static io.ballerina.lib.avro.Utils.getMutableType; import static io.ballerina.lib.avro.deserialize.visitor.RecordUtils.processArrayField; import static io.ballerina.lib.avro.deserialize.visitor.RecordUtils.processBytesField; @@ -188,22 +181,22 @@ public BArray visit(UnionDeserializer unionDeserializer, GenericData.Array { + case TypeTags.STRING_TAG -> { return visitStringArray(data); } - case FLOAT_TYPE -> { + case TypeTags.FLOAT_TAG -> { return visitDoubleArray(data); } - case BOOLEAN_TYPE -> { + case TypeTags.BOOLEAN_TAG -> { return visitBooleanArray(data); } - case INTEGER_TYPE -> { + case TypeTags.INT_TAG -> { return visitIntegerArray(data, schema); } - case RECORD_TYPE -> { + case TypeTags.RECORD_TYPE_TAG -> { return visitRecordArray(data, type, schema); } - case ARRAY_TYPE -> { + case TypeTags.ARRAY_TAG -> { return visitUnionArray(data, (ArrayType) type, schema); } default -> { @@ -234,14 +227,14 @@ public BArray visit(RecordDeserializer recordDeserializer, GenericData.Array { + case TypeTags.ARRAY_TAG -> { for (Object datum : data) { Type fieldType = ((ArrayType) type).getElementType().getCachedReferredType(); RecordDeserializer recordDes = new RecordDeserializer(schema.getElementType(), fieldType); recordList.add(recordDes.visit(this, (GenericRecord) datum)); } } - case REFERENCE_TYPE -> { + case TypeTags.TYPE_REFERENCED_TYPE_TAG -> { for (Object datum : data) { Type fieldType = ((ReferenceType) type).getReferredType(); RecordDeserializer recordDes = new RecordDeserializer(schema.getElementType(), fieldType); @@ -249,15 +242,11 @@ public BArray visit(RecordDeserializer recordDeserializer, GenericData.Array createAvroRecord(Type type) { - if (type instanceof IntersectionType) { - type = getMutableType((IntersectionType) type); - } - return ValueCreator.createRecordValue((RecordType) type); + return ValueCreator.createRecordValue((RecordType) getMutableType(type)); } private void processMaps(BMap avroRecord, Schema schema, @@ -405,7 +394,6 @@ public BString visitString(Object data) { public static Type extractMapType(Type type) throws Exception { Type mapType = type; - assert type instanceof RecordType; if (type.getTag() != TypeTags.RECORD_TYPE_TAG) { throw new Exception("Type is not a record type."); } @@ -445,7 +433,7 @@ public static RecordType extractRecordType(RecordType type) { case TypeTags.RECORD_TYPE_TAG -> recType = (RecordType) fieldType; case TypeTags.INTERSECTION_TAG -> { - Type getType = getMutableType((IntersectionType) fieldType); + Type getType = getMutableType(fieldType); if (getType.getTag() == TypeTags.RECORD_TYPE_TAG) { recType = (RecordType) getType; } diff --git a/native/src/main/java/io/ballerina/lib/avro/serialize/visitor/SerializeVisitor.java b/native/src/main/java/io/ballerina/lib/avro/serialize/visitor/SerializeVisitor.java index a13acc5..99a7ed2 100644 --- a/native/src/main/java/io/ballerina/lib/avro/serialize/visitor/SerializeVisitor.java +++ b/native/src/main/java/io/ballerina/lib/avro/serialize/visitor/SerializeVisitor.java @@ -28,6 +28,7 @@ import io.ballerina.lib.avro.serialize.UnionSerializer; import io.ballerina.lib.avro.serialize.visitor.array.ArrayVisitorFactory; import io.ballerina.lib.avro.serialize.visitor.array.IArrayVisitor; +import io.ballerina.runtime.api.TypeTags; import io.ballerina.runtime.api.types.Type; import io.ballerina.runtime.api.utils.StringUtils; import io.ballerina.runtime.api.utils.TypeUtils; @@ -44,13 +45,6 @@ import java.util.Map; import java.util.Objects; -import static io.ballerina.lib.avro.Utils.ARRAY_TYPE; -import static io.ballerina.lib.avro.Utils.FLOAT_TYPE; -import static io.ballerina.lib.avro.Utils.INTEGER_TYPE; -import static io.ballerina.lib.avro.Utils.MAP_TYPE; -import static io.ballerina.lib.avro.Utils.RECORD_TYPE; -import static io.ballerina.lib.avro.Utils.STRING_TYPE; - public class SerializeVisitor implements ISerializeVisitor { public Serializer createSerializer(Schema schema) { @@ -166,37 +160,24 @@ public Object visitUnion(UnionSerializer unionSerializer, Object data) throws Ex Schema fieldSchema = unionSerializer.getSchema(); Type typeName = TypeUtils.getType(data); switch (typeName.getTag()) { - case STRING_TYPE -> { + case TypeTags.STRING_TAG -> { return fieldSchema.getTypes().stream() .filter(type -> type.getType().equals(Schema.Type.ENUM)) .findFirst() .map(type -> visit(new EnumSerializer(type), data)) .orElse(visit(new PrimitiveDeserializer(fieldSchema), data.toString())); } - case ARRAY_TYPE -> { - for (Schema schema : fieldSchema.getTypes()) { - switch (schema.getType()) { - case BYTES -> { - return new PrimitiveDeserializer(schema).convert(this, data); - } - case FIXED -> { - return new FixedSerializer(schema).convert(this, data); - } - case ARRAY -> { - return new ArraySerializer(schema).convert(this, data); - } - } - } - return new ArraySerializer(fieldSchema).convert(this, data); + case TypeTags.ARRAY_TAG -> { + return visitUnionArrays(data, fieldSchema); } - case MAP_TYPE -> { + case TypeTags.MAP_TAG -> { return new MapSerializer(fieldSchema).convert(this, data); } - case RECORD_TYPE -> { - return new RecordSerializer(getRecordSchema(Schema.Type.RECORD, fieldSchema.getTypes())) - .convert(this, data); + case TypeTags.RECORD_TYPE_TAG -> { + return new RecordSerializer(getRecordSchema(Schema.Type.RECORD, + fieldSchema.getTypes())).convert(this, data); } - case INTEGER_TYPE -> { + case TypeTags.INT_TAG -> { return fieldSchema.getTypes().stream() .filter(schema -> schema.getType().equals(Schema.Type.INT)) .findFirst() @@ -204,13 +185,12 @@ public Object visitUnion(UnionSerializer unionSerializer, Object data) throws Ex .orElse(data); } - case FLOAT_TYPE -> { + case TypeTags.FLOAT_TAG -> { return fieldSchema.getTypes().stream() .filter(schema -> schema.getType().equals(Schema.Type.FLOAT)) .findFirst() .map(schema -> new PrimitiveDeserializer(schema).convert(this, data)) .orElse(data); - } default -> { return data; @@ -218,6 +198,23 @@ public Object visitUnion(UnionSerializer unionSerializer, Object data) throws Ex } } + private Object visitUnionArrays(Object data, Schema fieldSchema) { + for (Schema schema : fieldSchema.getTypes()) { + switch (schema.getType()) { + case BYTES -> { + return new PrimitiveDeserializer(schema).convert(this, data); + } + case FIXED -> { + return new FixedSerializer(schema).convert(this, data); + } + case ARRAY -> { + return new ArraySerializer(schema).convert(this, data); + } + } + } + return new ArraySerializer(fieldSchema).convert(this, data); + } + public static Schema getRecordSchema(Schema.Type givenType, List schemas) { for (Schema schema: schemas) { if (schema.getType().equals(Schema.Type.UNION)) {