Skip to content

Commit

Permalink
Introduce CEL as a replacement for Groovy message filters
Browse files Browse the repository at this point in the history
This is a naive and intermediate implementation that doesn't address CEL null-check "problems"
  • Loading branch information
DementevNikita committed Feb 4, 2024
1 parent c75c5cc commit 6127eee
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 131 deletions.
15 changes: 4 additions & 11 deletions api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,6 @@
<artifactId>spring-security-ldap</artifactId>
</dependency>


<dependency>
<groupId>org.codehaus.groovy</groupId>
<artifactId>groovy-jsr223</artifactId>
<version>${groovy.version}</version>
</dependency>
<dependency>
<groupId>org.codehaus.groovy</groupId>
<artifactId>groovy-json</artifactId>
<version>${groovy.version}</version>
</dependency>
<dependency>
<groupId>org.apache.datasketches</groupId>
<artifactId>datasketches-java</artifactId>
Expand All @@ -260,6 +249,10 @@
<artifactId>spring-boot-devtools</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>dev.cel</groupId>
<artifactId>cel</artifactId>
</dependency>

</dependencies>

Expand Down
178 changes: 115 additions & 63 deletions api/src/main/java/io/kafbat/ui/emitter/MessageFilters.java
Original file line number Diff line number Diff line change
@@ -1,98 +1,150 @@
package io.kafbat.ui.emitter;

import groovy.json.JsonSlurper;
import io.kafbat.ui.exception.ValidationException;
import static java.util.Collections.emptyMap;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelOptions;
import dev.cel.common.CelValidationException;
import dev.cel.common.CelValidationResult;
import dev.cel.common.types.MapType;
import dev.cel.common.types.SimpleType;
import dev.cel.compiler.CelCompiler;
import dev.cel.compiler.CelCompilerBuilder;
import dev.cel.compiler.CelCompilerFactory;
import dev.cel.parser.CelStandardMacro;
import dev.cel.runtime.CelEvaluationException;
import dev.cel.runtime.CelRuntime;
import dev.cel.runtime.CelRuntimeFactory;
import io.kafbat.ui.exception.CelException;
import io.kafbat.ui.model.MessageFilterTypeDTO;
import io.kafbat.ui.model.TopicMessageDTO;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import javax.script.CompiledScript;
import javax.script.ScriptEngineManager;
import javax.script.ScriptException;
import lombok.SneakyThrows;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.codehaus.groovy.jsr223.GroovyScriptEngineImpl;

@Slf4j
@UtilityClass
public class MessageFilters {
private static final CelCompiler CEL_COMPILER = createCompiler();
private static final CelRuntime CEL_RUNTIME = CelRuntimeFactory.standardCelRuntimeBuilder().build();

private static GroovyScriptEngineImpl GROOVY_ENGINE;

private MessageFilters() {
}
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

public static Predicate<TopicMessageDTO> createMsgFilter(String query, MessageFilterTypeDTO type) {
switch (type) {
case STRING_CONTAINS:
return containsStringFilter(query);
case GROOVY_SCRIPT:
return groovyScriptFilter(query);
default:
throw new IllegalStateException("Unknown query type: " + type);
}
return switch (type) {
case STRING_CONTAINS -> containsStringFilter(query);
case CEL_SCRIPT -> celScriptFilter(query);
};
}

static Predicate<TopicMessageDTO> containsStringFilter(String string) {
return msg -> StringUtils.contains(msg.getKey(), string)
|| StringUtils.contains(msg.getContent(), string);
}

static Predicate<TopicMessageDTO> groovyScriptFilter(String script) {
var engine = getGroovyEngine();
var compiledScript = compileScript(engine, script);
var jsonSlurper = new JsonSlurper();
return new Predicate<TopicMessageDTO>() {
@SneakyThrows
@Override
public boolean test(TopicMessageDTO msg) {
var bindings = engine.createBindings();
bindings.put("partition", msg.getPartition());
bindings.put("offset", msg.getOffset());
bindings.put("timestampMs", msg.getTimestamp().toInstant().toEpochMilli());
bindings.put("keyAsText", msg.getKey());
bindings.put("valueAsText", msg.getContent());
bindings.put("headers", msg.getHeaders());
bindings.put("key", parseToJsonOrReturnAsIs(jsonSlurper, msg.getKey()));
bindings.put("value", parseToJsonOrReturnAsIs(jsonSlurper, msg.getContent()));
var result = compiledScript.eval(bindings);
if (result instanceof Boolean) {
return (Boolean) result;
} else {
throw new ValidationException(
"Unexpected script result: %s, Boolean should be returned instead".formatted(result));
}
static Predicate<TopicMessageDTO> celScriptFilter(String script) {
CelValidationResult celValidationResult = CEL_COMPILER.compile(script);
if (celValidationResult.hasError()) {
throw new CelException(script, celValidationResult.getErrorString());
}

try {
CelAbstractSyntaxTree ast = celValidationResult.getAst();
CelRuntime.Program program = CEL_RUNTIME.createProgram(ast);

return createPredicate(script, program);
} catch (CelValidationException | CelEvaluationException e) {
throw new CelException(script, e);
}
}

private static Predicate<TopicMessageDTO> createPredicate(String originalScript, CelRuntime.Program program) {
return topicMessage -> {
Object programResult;
try {
programResult = program.eval(recordToArgs(topicMessage));
} catch (CelEvaluationException e) {
throw new CelException(originalScript, e);
}

if (programResult instanceof Boolean isMessageMatched) {
return isMessageMatched;
}

throw new CelException(
originalScript,
"Unexpected script result, boolean should be returned instead. Script output: %s".formatted(programResult)
);
};
}

@Nullable
private static Object parseToJsonOrReturnAsIs(JsonSlurper parser, @Nullable String str) {
if (str == null) {
return null;
private static Map<String, Object> recordToArgs(TopicMessageDTO topicMessage) {
Map<String, Object> args = new HashMap<>();

args.put("partition", topicMessage.getPartition());
args.put("offset", topicMessage.getOffset());

if (topicMessage.getTimestamp() != null) {
args.put("timestampMs", topicMessage.getTimestamp().toInstant().toEpochMilli());
}
try {
return parser.parseText(str);
} catch (Exception e) {
return str;

args.put("keyAsText", Objects.requireNonNullElse(topicMessage.getKey(), ""));
args.put("valueAsText", Objects.requireNonNullElse(topicMessage.getContent(), ""));

if (topicMessage.getKey() != null) {
args.put("key", parseToJsonOrReturnAsIs(topicMessage.getKey()));
} else {
args.put("key", emptyMap());
}
}

private static synchronized GroovyScriptEngineImpl getGroovyEngine() {
// it is pretty heavy object, so initializing it on-demand
if (GROOVY_ENGINE == null) {
GROOVY_ENGINE = (GroovyScriptEngineImpl)
new ScriptEngineManager().getEngineByName("groovy");
if (topicMessage.getContent() != null) {
args.put("value", parseToJsonOrReturnAsIs(topicMessage.getContent()));
} else {
args.put("value", emptyMap());
}
return GROOVY_ENGINE;

args.put("headers", Objects.requireNonNullElse(topicMessage.getHeaders(), emptyMap()));

return args;
}

private static CompiledScript compileScript(GroovyScriptEngineImpl engine, String script) {
private static CelCompiler createCompiler() {
CelCompilerBuilder celCompilerBuilder = CelCompilerFactory.standardCelCompilerBuilder()
.setOptions(CelOptions.DEFAULT)
.setStandardMacros(CelStandardMacro.STANDARD_MACROS);

celCompilerBuilder.addVar("partition", SimpleType.INT);
celCompilerBuilder.addVar("offset", SimpleType.INT);
celCompilerBuilder.addVar("timestampMs", SimpleType.INT);
celCompilerBuilder.addVar("keyAsText", SimpleType.STRING);
celCompilerBuilder.addVar("valueAsText", SimpleType.STRING);
celCompilerBuilder.addVar("headers", MapType.create(SimpleType.STRING, SimpleType.STRING));
celCompilerBuilder.addVar("key", SimpleType.DYN);
celCompilerBuilder.addVar("value", SimpleType.DYN);

return celCompilerBuilder
.setResultType(SimpleType.BOOL)
.build();
}

@Nullable
private static Object parseToJsonOrReturnAsIs(@Nullable String str) {
if (str == null) {
return null;
}

try {
return engine.compile(script);
} catch (ScriptException e) {
throw new ValidationException("Script syntax error: " + e.getMessage());
return OBJECT_MAPPER.readValue(str, new TypeReference<Map<String, Object>>() {});
} catch (JsonProcessingException e) {
return str;
}
}

}
22 changes: 22 additions & 0 deletions api/src/main/java/io/kafbat/ui/exception/CelException.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.kafbat.ui.exception;

public class CelException extends CustomBaseException {
private String celOriginalExpression;

public CelException(String celOriginalExpression, String errorMessage) {
super("CEL error. Original expression: %s. Error message: %s".formatted(celOriginalExpression, errorMessage));

this.celOriginalExpression = celOriginalExpression;
}

public CelException(String celOriginalExpression, Throwable celThrowable) {
super("CEL error. Original expression: %s".formatted(celOriginalExpression), celThrowable);

this.celOriginalExpression = celOriginalExpression;
}

@Override
public ErrorCode getErrorCode() {
return ErrorCode.CEL_ERROR;
}
}
1 change: 1 addition & 0 deletions api/src/main/java/io/kafbat/ui/exception/ErrorCode.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public enum ErrorCode {
SCHEMA_NOT_DELETED(4017, HttpStatus.INTERNAL_SERVER_ERROR),
TOPIC_ANALYSIS_ERROR(4018, HttpStatus.BAD_REQUEST),
FILE_UPLOAD_EXCEPTION(4019, HttpStatus.INTERNAL_SERVER_ERROR),
CEL_ERROR(4020, HttpStatus.BAD_REQUEST),
;

static {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public static SmartFilterTestExecutionResultDTO execSmartFilterTest(SmartFilterT
try {
predicate = MessageFilters.createMsgFilter(
execData.getFilterCode(),
MessageFilterTypeDTO.GROOVY_SCRIPT
MessageFilterTypeDTO.CEL_SCRIPT
);
} catch (Exception e) {
log.info("Smart filter '{}' compilation error", execData.getFilterCode(), e);
Expand Down
8 changes: 6 additions & 2 deletions api/src/test/java/io/kafbat/ui/AbstractIntegrationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ public abstract class AbstractIntegrationTest {
public static final String LOCAL = "local";
public static final String SECOND_LOCAL = "secondLocal";

private static final String CONFLUENT_PLATFORM_VERSION = "7.2.1"; // Append ".arm64" for a local run
private static final boolean IS_ARM =
System.getProperty("os.arch").contains("arm") || System.getProperty("os.arch").contains("aarch64");

private static final String CONFLUENT_PLATFORM_VERSION = IS_ARM ? "7.2.1.arm64" : "7.2.1";

public static final KafkaContainer kafka = new KafkaContainer(
DockerImageName.parse("confluentinc/cp-kafka").withTag(CONFLUENT_PLATFORM_VERSION))
Expand Down Expand Up @@ -71,7 +74,8 @@ public void initialize(@NotNull ConfigurableApplicationContext context) {
System.setProperty("kafka.clusters.0.name", LOCAL);
System.setProperty("kafka.clusters.0.bootstrapServers", kafka.getBootstrapServers());
// List unavailable hosts to verify failover
System.setProperty("kafka.clusters.0.schemaRegistry", String.format("http://localhost:%1$s,http://localhost:%1$s,%2$s",
System.setProperty("kafka.clusters.0.schemaRegistry",
String.format("http://localhost:%1$s,http://localhost:%1$s,%2$s",
TestSocketUtils.findAvailableTcpPort(), schemaRegistry.getUrl()));
System.setProperty("kafka.clusters.0.kafkaConnect.0.name", "kafka-connect");
System.setProperty("kafka.clusters.0.kafkaConnect.0.userName", "kafka-connect");
Expand Down
Loading

0 comments on commit 6127eee

Please sign in to comment.