Skip to content

Commit

Permalink
[#190] fix(core): Create serde class lazily (#192)
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Lazily create the serde class in the classloader when serializing and
deserializing

### Why are the changes needed?

Class Not Found Exception will be thrown when serializing or
deserializing catalog entity.

Fix: #190 

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Existing UTs
  • Loading branch information
mchades authored Aug 10, 2023
1 parent 12e7186 commit 46b1cb6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 103 deletions.
114 changes: 51 additions & 63 deletions core/src/main/java/com/datastrato/graviton/proto/ProtoEntitySerDe.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import com.google.protobuf.Message;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;

public class ProtoEntitySerDe implements EntitySerDe {

Expand All @@ -38,100 +37,89 @@ public class ProtoEntitySerDe implements EntitySerDe {

private final Map<Class<? extends Entity>, Class<? extends Message>> entityToProto;

private final Map<Class<? extends Message>, Class<? extends Entity>> protoToEntity;

public ProtoEntitySerDe() throws IOException {
ClassLoader loader =
Optional.ofNullable(Thread.currentThread().getContextClassLoader())
.orElse(getClass().getClassLoader());

// TODO. This potentially has issues in creating serde objects, because the class load here
// may have no context for entities which are implemented in the specific catalog module. We
// should lazily create the serde class in the classloader when serializing and deserializing.
public ProtoEntitySerDe() {
this.entityToSerDe = Maps.newHashMap();
for (Map.Entry<String, String> entry : ENTITY_TO_SERDE.entrySet()) {
String key = entry.getKey();
String s = entry.getValue();
Class<? extends Entity> entityClass = (Class<? extends Entity>) loadClass(key, loader);
Class<? extends ProtoSerDe<? extends Entity, ? extends Message>> serdeClass =
(Class<? extends ProtoSerDe<? extends Entity, ? extends Message>>) loadClass(s, loader);

try {
ProtoSerDe<? extends Entity, ? extends Message> serde = serdeClass.newInstance();
entityToSerDe.put(entityClass, serde);
} catch (Exception exception) {
throw new IOException("Failed to instantiate serde class " + s, exception);
}
}

this.entityToProto = Maps.newHashMap();
this.protoToEntity = Maps.newHashMap();
for (Map.Entry<String, String> entry : ENTITY_TO_PROTO.entrySet()) {
String e = entry.getKey();
String p = entry.getValue();
Class<? extends Entity> entityClass = (Class<? extends Entity>) loadClass(e, loader);
Class<? extends Message> protoClass = (Class<? extends Message>) loadClass(p, loader);
entityToProto.put(entityClass, protoClass);
protoToEntity.put(protoClass, entityClass);
}
}

@Override
public <T extends Entity> byte[] serialize(T t) throws IOException {
Any any = Any.pack(toProto(t));
Any any = Any.pack(toProto(t, Thread.currentThread().getContextClassLoader()));
return any.toByteArray();
}

@Override
public <T extends Entity> T deserialize(byte[] bytes, Class<T> clazz, ClassLoader classLoader)
throws IOException {
Any any = Any.parseFrom(bytes);
Class<? extends Message> protoClass = getProtoClass(clazz, classLoader);

if (!entityToSerDe.containsKey(clazz) || !entityToProto.containsKey(clazz)) {
throw new IOException("No proto and serde class found for entity " + clazz.getName());
}

if (!any.is(entityToProto.get(clazz))) {
if (!any.is(protoClass)) {
throw new IOException("Invalid proto for entity " + clazz.getName());
}

try {
Class<? extends Message> protoClazz = entityToProto.get(clazz);
Message anyMessage = any.unpack(protoClazz);
return fromProto(anyMessage);
} catch (Exception e) {
throw new IOException("Failed to deserialize entity " + clazz.getName(), e);
}
Message anyMessage = any.unpack(protoClass);
return fromProto(anyMessage, clazz, classLoader);
}

public <T extends Entity, M extends Message> M toProto(T t) throws IOException {
if (!entityToSerDe.containsKey(t.getClass())) {
throw new IOException("No serde found for entity " + t.getClass().getName());
private <T extends Entity, M extends Message> ProtoSerDe<T, M> getProtoSerde(
Class<T> entityClass, ClassLoader classLoader) throws IOException {
if (!ENTITY_TO_SERDE.containsKey(entityClass.getCanonicalName())
|| ENTITY_TO_SERDE.get(entityClass.getCanonicalName()) == null) {
throw new IOException("No serde found for entity " + entityClass.getCanonicalName());
}

ProtoSerDe<T, M> protoSerDe = (ProtoSerDe<T, M>) entityToSerDe.get(t.getClass());
return protoSerDe.serialize(t);
return (ProtoSerDe<T, M>)
entityToSerDe.computeIfAbsent(
entityClass,
k -> {
try {
Class<? extends ProtoSerDe<? extends Entity, ? extends Message>> serdeClazz =
(Class<? extends ProtoSerDe<? extends Entity, ? extends Message>>)
loadClass(ENTITY_TO_SERDE.get(k.getCanonicalName()), classLoader);
return serdeClazz.newInstance();
} catch (Exception e) {
throw new RuntimeException(
"Failed to instantiate serde class " + k.getCanonicalName(), e);
}
});
}

public <T extends Entity, M extends Message> T fromProto(M m) throws IOException {
if (!protoToEntity.containsKey(m.getClass())) {
throw new IOException("No entity class found for proto " + m.getClass().getName());
private Class<? extends Message> getProtoClass(
Class<? extends Entity> entityClass, ClassLoader classLoader) throws IOException {
if (!ENTITY_TO_PROTO.containsKey(entityClass.getCanonicalName())
|| ENTITY_TO_PROTO.get(entityClass.getCanonicalName()) == null) {
throw new IOException("No proto class found for entity " + entityClass.getCanonicalName());
}
Class<? extends Entity> entityClass = protoToEntity.get(m.getClass());
return entityToProto.computeIfAbsent(
entityClass,
k -> {
try {
return (Class<? extends Message>)
loadClass(ENTITY_TO_PROTO.get(k.getCanonicalName()), classLoader);
} catch (Exception e) {
throw new RuntimeException("Failed to create proto class " + k.getCanonicalName(), e);
}
});
}

if (!entityToSerDe.containsKey(entityClass)) {
throw new IOException("No serde found for entity " + entityClass.getName());
}
private <T extends Entity, M extends Message> M toProto(T t, ClassLoader classLoader)
throws IOException {
ProtoSerDe<T, M> protoSerDe = (ProtoSerDe<T, M>) getProtoSerde(t.getClass(), classLoader);
return protoSerDe.serialize(t);
}

ProtoSerDe<T, M> protoSerDe = (ProtoSerDe<T, M>) entityToSerDe.get(entityClass);
private <T extends Entity, M extends Message> T fromProto(
M m, Class<T> entityClass, ClassLoader classLoader) throws IOException {
ProtoSerDe<T, Message> protoSerDe = getProtoSerde(entityClass, classLoader);
return protoSerDe.deserialize(m);
}

private Class<?> loadClass(String className, ClassLoader classLoader) throws IOException {
try {
return Class.forName(className, true, classLoader);
} catch (Exception e) {
throw new IOException("Failed to load class " + className, e);
throw new IOException(
"Failed to load class " + className + " with classLoader " + classLoader, e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,10 @@ public void testAuditInfoSerDe() throws IOException {

ProtoEntitySerDe protoEntitySerDe = (ProtoEntitySerDe) entitySerDe;

AuditInfo auditInfoProto = protoEntitySerDe.toProto(auditInfo);
Assertions.assertEquals(creator, auditInfoProto.getCreator());
Assertions.assertEquals(now, ProtoUtils.toInstant(auditInfoProto.getCreateTime()));
Assertions.assertEquals(modifier, auditInfoProto.getLastModifier());
Assertions.assertEquals(now, ProtoUtils.toInstant(auditInfoProto.getLastModifiedTime()));

com.datastrato.graviton.meta.AuditInfo auditInfoFromProto =
protoEntitySerDe.fromProto(auditInfoProto);
Assertions.assertEquals(auditInfo, auditInfoFromProto);
byte[] bytes = protoEntitySerDe.serialize(auditInfo);
com.datastrato.graviton.meta.AuditInfo auditInfoFromBytes =
protoEntitySerDe.deserialize(bytes, com.datastrato.graviton.meta.AuditInfo.class);
Assertions.assertEquals(auditInfo, auditInfoFromBytes);

// Test with optional fields
com.datastrato.graviton.meta.AuditInfo auditInfo1 =
Expand All @@ -51,19 +46,10 @@ public void testAuditInfoSerDe() throws IOException {
.withCreateTime(now)
.build();

AuditInfo auditInfoProto1 = protoEntitySerDe.toProto(auditInfo1);

Assertions.assertEquals(creator, auditInfoProto1.getCreator());
Assertions.assertEquals(now, ProtoUtils.toInstant(auditInfoProto1.getCreateTime()));

com.datastrato.graviton.meta.AuditInfo auditInfoFromProto1 =
protoEntitySerDe.fromProto(auditInfoProto1);
Assertions.assertEquals(auditInfo1, auditInfoFromProto1);

// Test from/to bytes
byte[] bytes = entitySerDe.serialize(auditInfo1);
com.datastrato.graviton.meta.AuditInfo auditInfoFromBytes =
entitySerDe.deserialize(bytes, com.datastrato.graviton.meta.AuditInfo.class);
bytes = protoEntitySerDe.serialize(auditInfo1);
auditInfoFromBytes =
protoEntitySerDe.deserialize(bytes, com.datastrato.graviton.meta.AuditInfo.class);
Assertions.assertEquals(auditInfo1, auditInfoFromBytes);
}

Expand Down Expand Up @@ -94,12 +80,6 @@ public void testEntitiesSerDe() throws IOException {

ProtoEntitySerDe protoEntitySerDe = (ProtoEntitySerDe) entitySerDe;

Metalake metalakeProto = protoEntitySerDe.toProto(metalake);
Assertions.assertEquals(props, metalakeProto.getPropertiesMap());
com.datastrato.graviton.meta.BaseMetalake metalakeFromProto =
protoEntitySerDe.fromProto(metalakeProto);
Assertions.assertEquals(metalake, metalakeFromProto);

byte[] metalakeBytes = protoEntitySerDe.serialize(metalake);
com.datastrato.graviton.meta.BaseMetalake metalakeFromBytes =
protoEntitySerDe.deserialize(
Expand All @@ -115,15 +95,10 @@ public void testEntitiesSerDe() throws IOException {
.withVersion(version)
.build();

Metalake metalakeProto1 = protoEntitySerDe.toProto(metalake1);
Assertions.assertEquals(0, metalakeProto1.getPropertiesCount());
com.datastrato.graviton.meta.BaseMetalake metalakeFromProto1 =
protoEntitySerDe.fromProto(metalakeProto1);
Assertions.assertEquals(metalake1, metalakeFromProto1);

byte[] metalakeBytes1 = entitySerDe.serialize(metalake1);
byte[] metalakeBytes1 = protoEntitySerDe.serialize(metalake1);
com.datastrato.graviton.meta.BaseMetalake metalakeFromBytes1 =
entitySerDe.deserialize(metalakeBytes1, com.datastrato.graviton.meta.BaseMetalake.class);
protoEntitySerDe.deserialize(
metalakeBytes1, com.datastrato.graviton.meta.BaseMetalake.class);
Assertions.assertEquals(metalake1, metalakeFromBytes1);

// Test CatalogEntity
Expand All @@ -141,11 +116,6 @@ public void testEntitiesSerDe() throws IOException {
.withAuditInfo(auditInfo)
.build();

Catalog catalogProto = protoEntitySerDe.toProto(catalogEntity);
com.datastrato.graviton.meta.CatalogEntity catalogEntityFromProto =
protoEntitySerDe.fromProto(catalogProto);
Assertions.assertEquals(catalogEntity, catalogEntityFromProto);

byte[] catalogBytes = protoEntitySerDe.serialize(catalogEntity);
com.datastrato.graviton.meta.CatalogEntity catalogEntityFromBytes =
protoEntitySerDe.deserialize(
Expand Down

0 comments on commit 46b1cb6

Please sign in to comment.