Skip to content

Commit

Permalink
Move utility functions of record APIs to a seperate class
Browse files Browse the repository at this point in the history
  • Loading branch information
Nuvindu committed May 9, 2024
1 parent a5fef53 commit a4d9890
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import org.apache.avro.generic.GenericData;

public class ByteDeserializer extends Deserializer {

@Override
public Object visit(DeserializeVisitor visitor, Object data) throws Exception {
return visitor.visitBytes(data);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.apache.avro.generic.GenericRecord;

import java.nio.ByteBuffer;
import java.util.Map;

import static io.ballerina.lib.avro.Utils.getMutableType;
import static io.ballerina.lib.avro.deserialize.visitor.UnionRecordUtils.visitUnionRecords;
Expand Down Expand Up @@ -69,7 +68,7 @@ private void processMapField(BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) throws Exception {
Type mapType = extractMapType(avroRecord.getType());
MapDeserializer mapDeserializer = new MapDeserializer(field.schema(), mapType);
Object fieldValue = mapDeserializer.visit(this, (Map<String, Object>) fieldData);
Object fieldValue = mapDeserializer.visit(this, fieldData);
avroRecord.put(StringUtils.fromString(field.name()), fieldValue);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@
import io.ballerina.lib.avro.deserialize.Deserializer;
import io.ballerina.lib.avro.deserialize.EnumDeserializer;
import io.ballerina.lib.avro.deserialize.FixedDeserializer;
import io.ballerina.lib.avro.deserialize.GenericDeserializer;
import io.ballerina.lib.avro.deserialize.MapDeserializer;
import io.ballerina.lib.avro.deserialize.PrimitiveDeserializer;
import io.ballerina.lib.avro.deserialize.RecordDeserializer;
import io.ballerina.lib.avro.deserialize.StringDeserializer;
import io.ballerina.lib.avro.deserialize.UnionDeserializer;
import io.ballerina.runtime.api.TypeTags;
import io.ballerina.runtime.api.creators.ValueCreator;
Expand All @@ -36,6 +35,7 @@
import io.ballerina.runtime.api.types.RecordType;
import io.ballerina.runtime.api.types.ReferenceType;
import io.ballerina.runtime.api.types.Type;
import io.ballerina.runtime.api.utils.StringUtils;
import io.ballerina.runtime.api.utils.TypeUtils;
import io.ballerina.runtime.api.utils.ValueUtils;
import io.ballerina.runtime.api.values.BArray;
Expand All @@ -59,7 +59,12 @@
import static io.ballerina.lib.avro.Utils.REFERENCE_TYPE;
import static io.ballerina.lib.avro.Utils.STRING_TYPE;
import static io.ballerina.lib.avro.Utils.getMutableType;
import static io.ballerina.lib.avro.deserialize.visitor.UnionRecordUtils.visitUnionRecords;
import static io.ballerina.lib.avro.deserialize.visitor.RecordUtils.processArrayField;
import static io.ballerina.lib.avro.deserialize.visitor.RecordUtils.processBytesField;
import static io.ballerina.lib.avro.deserialize.visitor.RecordUtils.processMapField;
import static io.ballerina.lib.avro.deserialize.visitor.RecordUtils.processRecordField;
import static io.ballerina.lib.avro.deserialize.visitor.RecordUtils.processStringField;
import static io.ballerina.lib.avro.deserialize.visitor.RecordUtils.processUnionField;
import static io.ballerina.runtime.api.utils.StringUtils.fromString;

public class DeserializeVisitor implements IDeserializeVisitor {
Expand All @@ -77,7 +82,7 @@ public static Deserializer createDeserializer(Schema schema, Type type) {
case FIXED ->
new FixedDeserializer(schema, type);
default ->
new GenericDeserializer(schema, type);
new PrimitiveDeserializer(schema, type);
};
}

Expand Down Expand Up @@ -128,29 +133,30 @@ public BMap<BString, Object> visit(MapDeserializer mapDeserializer, Map<String,
case ARRAY ->
processMapArray(avroRecord, schema, (MapType) type, key, (GenericData.Array<Object>) value);
case BYTES ->
avroRecord.put(fromString(key.toString()),
ValueCreator.createArrayValue(((ByteBuffer) value).array()));
avroRecord.put(StringUtils.fromString(key.toString()),
ValueCreator.createArrayValue(((ByteBuffer) value).array()));
case FIXED ->
avroRecord.put(fromString(key.toString()),
ValueCreator.createArrayValue(((GenericFixed) value).bytes()));
avroRecord.put(StringUtils.fromString(key.toString()),
ValueCreator.createArrayValue(((GenericFixed) value).bytes()));
case ENUM, STRING ->
avroRecord.put(fromString(key.toString()),
fromString(value.toString()));
avroRecord.put(StringUtils.fromString(key.toString()),
StringUtils.fromString(value.toString()));
case RECORD ->
processMapRecord(avroRecord, schema, (MapType) type, key, (GenericRecord) value);
case FLOAT ->
avroRecord.put(fromString(key.toString()), Double.parseDouble(value.toString()));
avroRecord.put(StringUtils.fromString(key.toString()),
Double.parseDouble(value.toString()));
case MAP ->
processMaps(avroRecord, schema, (MapType) type, key, (Map<String, Object>) value);
default ->
avroRecord.put(fromString(key.toString()), value);
avroRecord.put(StringUtils.fromString(key.toString()), value);
}
}
return (BMap<BString, Object>) ValueUtils.convert(avroRecord, type);
}

public Object visit(GenericDeserializer genericDeserializer, Object data) {
Schema schema = genericDeserializer.getSchema();
public Object visit(PrimitiveDeserializer primitiveDeserializer, Object data) {
Schema schema = primitiveDeserializer.getSchema();
if (schema.getType().equals(Schema.Type.ARRAY)) {
GenericData.Array<Object> array = (GenericData.Array<Object>) data;
switch (schema.getElementType().getType()) {
Expand All @@ -170,7 +176,7 @@ public Object visit(GenericDeserializer genericDeserializer, Object data) {
return visitBooleanArray(array);
}
default -> {
return visitBytesArray(array, genericDeserializer.getType());
return visitBytesArray(array, primitiveDeserializer.getType());
}
}
} else {
Expand All @@ -195,27 +201,34 @@ public BArray visit(UnionDeserializer unionDeserializer, GenericData.Array<Objec
return visitIntegerArray(data, schema);
}
case RECORD_TYPE -> {
RecordDeserializer recordDeserializer = new RecordDeserializer(schema.getElementType(), type);
return (BArray) recordDeserializer.visit(this, data);
return visitRecordArray(data, type, schema);
}
case ARRAY_TYPE -> {
Object[] objects = new Object[data.size()];
Type elementType = ((ArrayType) type).getElementType();
ArrayDeserializer arrayDeserializer = new ArrayDeserializer(schema.getElementType(), elementType);
int index = 0;
for (Object currentData : data) {
Object deserializedObject = arrayDeserializer.visit(this, (GenericData.Array<Object>) currentData);
objects[index++] = deserializedObject;
}
return ValueCreator.createArrayValue(objects, (ArrayType) type);

return visitUnionArray(data, (ArrayType) type, schema);
}
default -> {
return visitBytes(data);
}
}
}

private BArray visitRecordArray(GenericData.Array<Object> data, Type type, Schema schema) throws Exception {
RecordDeserializer recordDeserializer = new RecordDeserializer(schema.getElementType(), type);
return (BArray) recordDeserializer.visit(this, data);
}

private BArray visitUnionArray(GenericData.Array<Object> data, ArrayType type, Schema schema) throws Exception {
Object[] objects = new Object[data.size()];
Type elementType = type.getElementType();
ArrayDeserializer arrayDeserializer = new ArrayDeserializer(schema.getElementType(), elementType);
int index = 0;
for (Object currentData : data) {
Object deserializedObject = arrayDeserializer.visit(this, (GenericData.Array<Object>) currentData);
objects[index++] = deserializedObject;
}
return ValueCreator.createArrayValue(objects, type);
}

public BArray visit(RecordDeserializer recordDeserializer, GenericData.Array<Object> data) throws Exception {
List<Object> recordList = new ArrayList<>();
Type type = recordDeserializer.getType();
Expand Down Expand Up @@ -247,47 +260,6 @@ private BMap<BString, Object> createAvroRecord(Type type) {
return ValueCreator.createRecordValue((RecordType) type);
}

private void processMapField(BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) throws Exception {
Type mapType = extractMapType(avroRecord.getType());
MapDeserializer mapDeserializer = new MapDeserializer(field.schema(), mapType);
Object fieldValue = mapDeserializer.visit(this, (Map<String, Object>) fieldData);
avroRecord.put(fromString(field.name()), fieldValue);
}

private void processArrayField(BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) throws Exception {
ArrayDeserializer arrayDes = new ArrayDeserializer(field.schema(), avroRecord.getType());
Object fieldValue = arrayDes.visit(this, (GenericData.Array<Object>) fieldData);
avroRecord.put(fromString(field.name()), fieldValue);
}

private void processBytesField(BMap<BString, Object> avroRecord, Schema.Field field, Object fieldData) {
ByteBuffer byteBuffer = (ByteBuffer) fieldData;
Object fieldValue = ValueCreator.createArrayValue(byteBuffer.array());
avroRecord.put(fromString(field.name()), fieldValue);
}

private void processRecordField(BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) throws Exception {
Type recType = extractRecordType((RecordType) avroRecord.getType());
RecordDeserializer recordDes = new RecordDeserializer(field.schema(), recType);
Object fieldValue = recordDes.visit(this, (GenericRecord) fieldData);
avroRecord.put(fromString(field.name()), fieldValue);
}

private void processStringField(BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) throws Exception {
StringDeserializer stringDes = new StringDeserializer();
Object fieldValue = stringDes.visit(this, fieldData);
avroRecord.put(fromString(field.name()), fieldValue);
}

private void processUnionField(Type type, BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) throws Exception {
visitUnionRecords(type, avroRecord, field, fieldData);
}

private void processMaps(BMap<BString, Object> avroRecord, Schema schema,
MapType type, Object key, Map<String, Object> value) throws Exception {
Schema fieldSchema = schema.getValueType();
Expand Down Expand Up @@ -431,29 +403,31 @@ public BString visitString(Object data) {
return fromString(data.toString());
}

public static Type extractMapType(Type type) {
public static Type extractMapType(Type type) throws Exception {
Type mapType = type;
assert type instanceof RecordType;
if (type.getTag() != TypeTags.RECORD_TYPE_TAG) {
throw new Exception("Type is not a record type.");
}
for (Map.Entry<String, Field> entry : ((RecordType) type).getFields().entrySet()) {
Field fieldValue = entry.getValue();
if (fieldValue != null) {
Type fieldType = fieldValue.getFieldType();
switch (fieldType.getTag()) {
case TypeTags.MAP_TAG:
mapType = fieldType;
break;
case TypeTags.INTERSECTION_TAG:
case TypeTags.MAP_TAG ->
mapType = fieldType;
case TypeTags.INTERSECTION_TAG -> {
Type referredType = getMutableType((IntersectionType) fieldType);
if (referredType.getTag() == TypeTags.MAP_TAG) {
mapType = referredType;
}
break;
default:
}
default -> {
Type referType = TypeUtils.getReferredType(fieldType);
if (referType.getTag() == TypeTags.MAP_TAG) {
mapType = referType;
}
break;
}
}
}
}
Expand All @@ -468,21 +442,20 @@ public static RecordType extractRecordType(RecordType type) {
if (fieldValue != null) {
Type fieldType = fieldValue.getFieldType();
switch (fieldType.getTag()) {
case TypeTags.RECORD_TYPE_TAG:
recType = (RecordType) fieldType;
break;
case TypeTags.INTERSECTION_TAG:
case TypeTags.RECORD_TYPE_TAG ->
recType = (RecordType) fieldType;
case TypeTags.INTERSECTION_TAG -> {
Type getType = getMutableType((IntersectionType) fieldType);
if (getType.getTag() == TypeTags.RECORD_TYPE_TAG) {
recType = (RecordType) getType;
}
break;
default:
}
default -> {
Type referredType = TypeUtils.getReferredType(fieldType);
if (referredType.getTag() == TypeTags.RECORD_TYPE_TAG) {
recType = (RecordType) referredType;
}
break;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package io.ballerina.lib.avro.deserialize.visitor;

import io.ballerina.lib.avro.deserialize.ArrayDeserializer;
import io.ballerina.lib.avro.deserialize.MapDeserializer;
import io.ballerina.lib.avro.deserialize.RecordDeserializer;
import io.ballerina.lib.avro.deserialize.StringDeserializer;
import io.ballerina.runtime.api.creators.ValueCreator;
import io.ballerina.runtime.api.types.RecordType;
import io.ballerina.runtime.api.types.Type;
import io.ballerina.runtime.api.values.BMap;
import io.ballerina.runtime.api.values.BString;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;

import java.nio.ByteBuffer;

import static io.ballerina.lib.avro.deserialize.visitor.DeserializeVisitor.extractMapType;
import static io.ballerina.lib.avro.deserialize.visitor.DeserializeVisitor.extractRecordType;
import static io.ballerina.lib.avro.deserialize.visitor.UnionRecordUtils.visitUnionRecords;
import static io.ballerina.runtime.api.utils.StringUtils.fromString;

public class RecordUtils {

public static void processMapField(BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) throws Exception {
Type mapType = extractMapType(avroRecord.getType());
MapDeserializer mapDeserializer = new MapDeserializer(field.schema(), mapType);
Object fieldValue = mapDeserializer.visit(new DeserializeVisitor(), fieldData);
avroRecord.put(fromString(field.name()), fieldValue);
}

public static void processArrayField(BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) throws Exception {
ArrayDeserializer arrayDes = new ArrayDeserializer(field.schema(), avroRecord.getType());
Object fieldValue = arrayDes.visit(new DeserializeVisitor(), (GenericData.Array<Object>) fieldData);
avroRecord.put(fromString(field.name()), fieldValue);
}

public static void processBytesField(BMap<BString, Object> avroRecord, Schema.Field field, Object fieldData) {
ByteBuffer byteBuffer = (ByteBuffer) fieldData;
Object fieldValue = ValueCreator.createArrayValue(byteBuffer.array());
avroRecord.put(fromString(field.name()), fieldValue);
}

public static void processRecordField(BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) throws Exception {
Type recType = extractRecordType((RecordType) avroRecord.getType());
RecordDeserializer recordDes = new RecordDeserializer(field.schema(), recType);
Object fieldValue = recordDes.visit(new DeserializeVisitor(), fieldData);
avroRecord.put(fromString(field.name()), fieldValue);
}

public static void processStringField(BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) {
StringDeserializer stringDes = new StringDeserializer();
Object fieldValue = stringDes.visit(new DeserializeVisitor(), fieldData);
avroRecord.put(fromString(field.name()), fieldValue);
}

public static void processUnionField(Type type, BMap<BString, Object> avroRecord,
Schema.Field field, Object fieldData) throws Exception {
visitUnionRecords(type, avroRecord, field, fieldData);
}
}

0 comments on commit a4d9890

Please sign in to comment.