Skip to content

Commit

Permalink
[GR-41558] Support for lambda class predefinition on Native Image
Browse files Browse the repository at this point in the history
PullRequest: graal/12952
  • Loading branch information
sstanoje committed Feb 5, 2024
2 parents bea58b8 + d84c35d commit 36af33b
Show file tree
Hide file tree
Showing 17 changed files with 422 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*.swo
*.swp
*.zip
*.bgv
.DS_Store
.checkstyle
.classpath
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,20 @@ public static String findStableLambdaName(ClassInitializationPlugin cip, Provide
return createStableLambdaName(lambdaType, invokedMethods);
}

/**
* Checks if the passed type is lambda class type based on set flags and the type name.
*
* @param type type to be checked
* @return true if the passed type is lambda type, false otherwise
*/

public static boolean isLambdaType(ResolvedJavaType type) {
String typeName = type.getName();
return type.isFinalFlagSet() && isLambdaName(typeName);
}

public static boolean isLambdaName(String name) {
return name.contains(LAMBDA_CLASS_NAME_SUBSTRING) && lambdaMatcher(name).find();
return isLambdaClassName(name) && lambdaMatcher(name).find();
}

private static String createStableLambdaName(ResolvedJavaType lambdaType, List<ResolvedJavaMethod> targetMethods) {
Expand All @@ -157,17 +164,59 @@ public static String toHex(byte[] data) {
return r.toString();
}

/**
* Hashing a passed string parameter using SHA-1 hashing algorithm.
*
* @param value string to be hashed
* @return hexadecimal hashed value of the passed string parameter
*/
public static String digest(String value) {
return digest(value.getBytes(StandardCharsets.UTF_8));
}

/**
* Hashing a passed byte array parameter using SHA-1 hashing algorithm.
*
* @param bytes byte array to be hashed
* @return hexadecimal hashed value of the passed byte array parameter
*/
public static String digest(byte[] bytes) {
try {
MessageDigest md = MessageDigest.getInstance("SHA-1");
md.update(value.getBytes(StandardCharsets.UTF_8));
md.update(bytes);
return toHex(md.digest());
} catch (NoSuchAlgorithmException ex) {
throw new JVMCIError(ex);
}
}

/**
* Extracts lambda capturing class name from the lambda class name.
*
* @param className name of the lambda class
* @return name of the lambda capturing class
*/
public static String capturingClass(String className) {
return className.split(LambdaUtils.SERIALIZATION_TEST_LAMBDA_CLASS_SPLIT_PATTERN)[0];
}

/**
* Checks if the passed class is lambda class.
*
* @param clazz class to be checked
* @return true if the clazz is lambda class, false instead
*/
public static boolean isLambdaClass(Class<?> clazz) {
return isLambdaClassName(clazz.getName());
}

/**
* Checks if the passed class name is lambda class name.
*
* @param className name of the class
* @return true if the className is lambda class name, false instead
*/
public static boolean isLambdaClassName(String className) {
return className.contains(LAMBDA_CLASS_NAME_SUBSTRING);
}
}
5 changes: 4 additions & 1 deletion sdk/mx.sdk/mx_sdk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,10 @@ def extra_agentlib_options(self, benchmark, args, image_run_args):
The returned options are added to the agentlib:native-image-agent option list.
The config-output-dir is configured by the benchmark runner and cannot be overridden.
"""
return []

# All Renaissance Spark benchmarks require lambda class predefinition, so we need this additional option that
# is used for the class predefinition feature. See GR-37506
return ['experimental-class-define-support'] if (benchmark in ['chi-square', 'gauss-mix', 'movie-lens', 'page-rank']) else []

def extra_profile_run_arg(self, benchmark, args, image_run_args, should_strip_run_args):
"""Returns all arguments passed to the profiling run.
Expand Down
36 changes: 32 additions & 4 deletions substratevm/mx.substratevm/mx_substratevm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,19 @@ def list_jars(path):


force_buildtime_init_slf4j_1_7_73 = '--initialize-at-build-time=org.slf4j,org.apache.log4j'
force_buildtime_init_slf4j_1_7_73_spark = '--initialize-at-build-time=org.apache.logging.slf4j.Log4jLoggerFactory,\
org.apache.logging.slf4j.SLF4JServiceProvider,org.apache.logging.slf4j.Log4jMarkerFactory,org.apache.logging.slf4j.Log4jMDCAdapter,\
org.apache.logging.log4j,org.apache.logging.log4j,org.apache.logging.log4j.core.util.WatchManager,org.apache.logging.log4j.core.config.xml.XmlConfiguration, \
org.apache.logging.log4j.core.config.AbstractConfiguration,org.apache.logging.log4j.util.ServiceLoaderUtil,org.slf4j.LoggerFactory'
force_buildtime_init_netty_4_1_72 = '--initialize-at-build-time=io.netty.util.internal.logging'
force_runtime_init_slf4j_1_7_73 = '--initialize-at-run-time=org.apache.log4j.LogManager'
force_runtime_init_netty_4_1_72 = '--initialize-at-run-time=io.netty.channel.unix,io.netty.channel.epoll,io.netty.handler.codec.http2,io.netty.handler.ssl,io.netty.internal.tcnative,io.netty.util.internal.logging.Log4JLogger'
force_runtime_init_netty_4_1_72_spark = '--initialize-at-run-time=io.netty.buffer.AbstractByteBufAllocator\
io.netty.channel.AbstractChannelHandlerContext,io.netty.channel.ChannelInitializer,io.netty.channel.ChannelOutboundBuffer,\
io.netty.util.internal.SystemPropertyUtil,io.netty.channel.AbstractChannel,io.netty.util.internal.PlatformDependent,\
io.netty.util.internal.InternalThreadLocalMap,io.netty.channel.socket.nio.SelectorProviderUtil,io.netty.util.concurrent.DefaultPromise, \
io.netty.util.NetUtil,io.netty.channel.DefaultChannelPipeline,io.netty.util.concurrent.FastThreadLocalThread,io.netty.util.internal.StringUtil, \
io.netty.util.internal.PlatformDependent0,io.netty.util,io.netty.bootstrap,io.netty.channel,io.netty.buffer,io.netty.resolver,io.netty.handler.codec.CodecOutputList'
_RENAISSANCE_EXTRA_IMAGE_BUILD_ARGS = {
'als' : [
'--report-unsupported-elements-at-runtime',
Expand All @@ -69,7 +81,11 @@ def list_jars(path):
'chi-square' : [
'--report-unsupported-elements-at-runtime',
force_buildtime_init_slf4j_1_7_73,
force_runtime_init_netty_4_1_72
force_buildtime_init_slf4j_1_7_73_spark,
force_buildtime_init_netty_4_1_72,
force_runtime_init_netty_4_1_72,
force_runtime_init_netty_4_1_72_spark,
force_runtime_init_slf4j_1_7_73
],
'finagle-chirper' : [
force_buildtime_init_slf4j_1_7_73,
Expand All @@ -87,7 +103,11 @@ def list_jars(path):
'movie-lens' : [
'--report-unsupported-elements-at-runtime',
force_buildtime_init_slf4j_1_7_73,
force_runtime_init_netty_4_1_72
force_buildtime_init_slf4j_1_7_73_spark,
force_buildtime_init_netty_4_1_72,
force_runtime_init_netty_4_1_72,
force_runtime_init_netty_4_1_72_spark,
force_runtime_init_slf4j_1_7_73
],
'dec-tree' : [
'--report-unsupported-elements-at-runtime',
Expand All @@ -97,7 +117,11 @@ def list_jars(path):
'page-rank' : [
'--report-unsupported-elements-at-runtime',
force_buildtime_init_slf4j_1_7_73,
force_runtime_init_netty_4_1_72
force_buildtime_init_slf4j_1_7_73_spark,
force_buildtime_init_netty_4_1_72,
force_runtime_init_netty_4_1_72,
force_runtime_init_netty_4_1_72_spark,
force_runtime_init_slf4j_1_7_73
],
'naive-bayes' : [
'--report-unsupported-elements-at-runtime',
Expand All @@ -107,7 +131,11 @@ def list_jars(path):
'gauss-mix' : [
'--report-unsupported-elements-at-runtime',
force_buildtime_init_slf4j_1_7_73,
force_runtime_init_netty_4_1_72
force_buildtime_init_slf4j_1_7_73_spark,
force_buildtime_init_netty_4_1_72,
force_runtime_init_netty_4_1_72,
force_runtime_init_netty_4_1_72_spark,
force_runtime_init_slf4j_1_7_73
],
'neo4j-analytics': [
'--report-unsupported-elements-at-runtime',
Expand Down
1 change: 1 addition & 0 deletions substratevm/mx.substratevm/suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@
"jdk.internal.reflect",
"jdk.internal.vm",
"jdk.internal.util",
"jdk.internal.org.objectweb.asm",
],
"java.management": [
"com.sun.jmx.mbeanserver",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ public Set<String> classesSet(boolean packageNameOnly) {
String name = method.getDeclaringClass().toJavaName(true);
if (packageNameOnly) {
name = packagePrefix(name);
if (name.contains(LambdaUtils.LAMBDA_CLASS_NAME_SUBSTRING)) {
if (LambdaUtils.isLambdaClassName(name)) {
/* Also strip synthetic package names added for lambdas. */
name = packagePrefix(name);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Supplier;

import com.oracle.svm.core.jni.headers.JNIMode;
import jdk.graal.compiler.core.common.NumUtil;
import jdk.graal.compiler.java.LambdaUtils;
import org.graalvm.nativeimage.StackValue;
import org.graalvm.nativeimage.UnmanagedMemory;
import org.graalvm.nativeimage.c.function.CEntryPoint;
Expand All @@ -77,6 +80,7 @@
import org.graalvm.nativeimage.c.type.WordPointer;
import org.graalvm.word.WordFactory;

import com.oracle.svm.agent.stackaccess.EagerlyLoadedJavaStackAccess;
import com.oracle.svm.agent.stackaccess.InterceptedState;
import com.oracle.svm.agent.tracing.core.Tracer;
import com.oracle.svm.configure.trace.AccessAdvisor;
Expand All @@ -102,8 +106,6 @@
import com.oracle.svm.jvmtiagentbase.jvmti.JvmtiInterface;
import com.oracle.svm.jvmtiagentbase.jvmti.JvmtiLocationFormat;

import jdk.graal.compiler.core.common.NumUtil;

/**
* Intercepts events of interest via breakpoints in Java code.
* <p>
Expand Down Expand Up @@ -961,11 +963,75 @@ private static boolean methodTypeFromDescriptor(JNIEnvironment jni, JNIObjectHan
}

/**
* We have to find a class that captures a lambda function so it can be registered by the agent.
* We have to get a SerializedLambda instance first. After that we get a lambda capturing class
* from that instance using JNIHandleSet#getFieldId to get field id and JNIObjectHandle#invoke
* on to get that field value. We get a name of the capturing class and tell the agent to
* register it.
* This method should be intercepted when we are predefining a lambda class. This is the only
* spot in the lambda-class creation pipeline where we can get lambda-class bytecode so the
* class can be predefined. We do not want to predefine all lambda classes, but only the ones
* that are actually created at runtime, so we have a method that checks wheter the lambda
* should be predefined or not.
*/
private static boolean onMethodHandleClassFileInit(JNIEnvironment jni, JNIObjectHandle thread, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) {
String className = Support.fromJniString(jni, getObjectArgument(thread, 1));

if (LambdaUtils.isLambdaClassName(className)) {
if (shouldIgnoreLambdaClassForPredefinition(jni)) {
return true;
}

JNIObjectHandle bytesArray = getObjectArgument(thread, 3);
int length = jniFunctions().getGetArrayLength().invoke(jni, bytesArray);
byte[] data = new byte[length];

CCharPointer bytesArrayCharPointer = jni.getFunctions().getGetByteArrayElements().invoke(jni, bytesArray, WordFactory.nullPointer());
if (bytesArrayCharPointer.isNonNull()) {
try {
CTypeConversion.asByteBuffer(bytesArrayCharPointer, length).get(data);
} finally {
jni.getFunctions().getReleaseByteArrayElements().invoke(jni, bytesArray, bytesArrayCharPointer, JNIMode.JNI_ABORT());
}

className += LambdaUtils.digest(data);
tracer.traceCall("classloading", "onMethodHandleClassFileInit", null, null, null, null, state.getFullStackTraceOrNull(), className, data);
}
}
return true;
}

/**
* This method is used to check whether a lambda class should be predefined or not. Only lambdas
* that are created at runtime should be predefined, and we should ignore the others. This
* method checks if the specific sequence of methods exists in the stacktrace and base on that
* decides if the lambda class should be ignored.
*/
private static boolean shouldIgnoreLambdaClassForPredefinition(JNIEnvironment env) {
JNIMethodId[] stackTraceMethodIds = EagerlyLoadedJavaStackAccess.stackAccessSupplier().get().getFullStackTraceOrNull();
JNIMethodId javaLangInvokeCallSiteMakeSite = agent.handles().getJavaLangInvokeCallSiteMakeSite(env);
JNIMethodId javaLangInvokeMethodHandleNativesLinkCallSiteImpl = agent.handles().getJavaLangInvokeMethodHandleNativesLinkCallSiteImpl(env);
JNIMethodId javaLangInvokeMethodHandleNativesLinkCallSite = agent.handles().getJavaLangInvokeMethodHandleNativesLinkCallSite(env);

/*
* Sequence {@code java.lang.invoke.CallSite.makeSite}, {@code
* java.lang.invoke.MethodHandleNatives.linkCallSiteImpl}, {@code
* java.lang.invoke.MethodHandleNatives.linkCallSite} in the stacktrace indicates that
* lambda class won't be created at runtime on the Native Image, so it should not be
* registered for predefiniton.
*/
for (int i = 0; i < stackTraceMethodIds.length - 2; i++) {
if (stackTraceMethodIds[i] == javaLangInvokeCallSiteMakeSite &&
stackTraceMethodIds[i + 1] == javaLangInvokeMethodHandleNativesLinkCallSiteImpl &&
stackTraceMethodIds[i + 2] == javaLangInvokeMethodHandleNativesLinkCallSite) {
return true;
}
}

return false;
}

/**
* We have to find a class that captures a lambda function, so it can be registered by the
* agent. We have to get a SerializedLambda instance first. After that we get a lambda capturing
* class from that instance using JNIHandleSet#getFieldId to get field id and
* JNIObjectHandle#invoke on to get that field value. We get a name of the capturing class and
* tell the agent to register it.
*/
private static boolean serializedLambdaReadResolve(JNIEnvironment jni, JNIObjectHandle thread, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) {
JNIObjectHandle serializedLambdaInstance = getReceiver(thread);
Expand Down Expand Up @@ -1283,6 +1349,12 @@ public static void onVMInit(JvmtiEnv jvmti, JNIEnvironment jni) {
System.arraycopy(BREAKPOINT_SPECIFICATIONS, 0, breakpointSpecifications, 0, BREAKPOINT_SPECIFICATIONS.length);
System.arraycopy(REFLECTION_ACCESS_BREAKPOINT_SPECIFICATIONS, 0, breakpointSpecifications, BREAKPOINT_SPECIFICATIONS.length, REFLECTION_ACCESS_BREAKPOINT_SPECIFICATIONS.length);
}
if (experimentalClassDefineSupport) {
BreakpointSpecification[] existingBreakpointSpecifications = breakpointSpecifications;
breakpointSpecifications = Arrays.copyOf(existingBreakpointSpecifications, existingBreakpointSpecifications.length + CLASS_PREDEFINITION_BREAKPOINT_SPECIFICATIONS.length);
System.arraycopy(CLASS_PREDEFINITION_BREAKPOINT_SPECIFICATIONS, 0, breakpointSpecifications, existingBreakpointSpecifications.length,
CLASS_PREDEFINITION_BREAKPOINT_SPECIFICATIONS.length);
}
for (BreakpointSpecification br : breakpointSpecifications) {
JNIObjectHandle clazz = nullHandle();
if (lastClassName != null && lastClassName.equals(br.className)) {
Expand Down Expand Up @@ -1652,6 +1724,10 @@ private static boolean allocateInstance(JNIEnvironment jni, JNIObjectHandle thre
brk("java/lang/reflect/Constructor", "newInstance", "([Ljava/lang/Object;)Ljava/lang/Object;", BreakpointInterceptor::invokeConstructor),
};

private static final BreakpointSpecification[] CLASS_PREDEFINITION_BREAKPOINT_SPECIFICATIONS = {
brk("java/lang/invoke/MethodHandles$Lookup$ClassFile", "<init>", "(Ljava/lang/String;I[B)V", BreakpointInterceptor::onMethodHandleClassFileInit),
};

private static BreakpointSpecification brk(String className, String methodName, String signature, BreakpointHandler handler) {
return new BreakpointSpecification(className, methodName, signature, handler, false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ public class NativeImageAgentJNIHandleSet extends JNIHandleSet {

final JNIMethodId javaLangModuleGetName;

private JNIMethodId javaLangInvokeCallSiteMakeSite = WordFactory.nullPointer();
private JNIMethodId javaLangInvokeMethodHandleNativesLinkCallSiteImpl = WordFactory.nullPointer();
private JNIMethodId javaLangInvokeMethodHandleNativesLinkCallSite = WordFactory.nullPointer();

NativeImageAgentJNIHandleSet(JNIEnvironment env) {
super(env);
javaLangClass = newClassGlobalRef(env, "java/lang/Class");
Expand Down Expand Up @@ -263,4 +267,32 @@ public JNIMethodId getJavaUtilResourceBundleGetLocale(JNIEnvironment env) {
}
return javaUtilResourceBundleGetLocale;
}

public JNIMethodId getJavaLangInvokeCallSiteMakeSite(JNIEnvironment env) {
if (javaLangInvokeCallSiteMakeSite.isNull()) {
JNIObjectHandle javaLangInvokeCallSite = findClass(env, "java/lang/invoke/CallSite");
javaLangInvokeCallSiteMakeSite = getMethodId(env, javaLangInvokeCallSite, "makeSite",
"(Ljava/lang/invoke/MethodHandle;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/Object;Ljava/lang/Class;)Ljava/lang/invoke/CallSite;", true);
}
return javaLangInvokeCallSiteMakeSite;
}

public JNIMethodId getJavaLangInvokeMethodHandleNativesLinkCallSiteImpl(JNIEnvironment env) {
if (javaLangInvokeMethodHandleNativesLinkCallSiteImpl.isNull()) {
JNIObjectHandle javaLangInvokeMethodHandleNatives = findClass(env, "java/lang/invoke/MethodHandleNatives");
javaLangInvokeMethodHandleNativesLinkCallSiteImpl = getMethodId(env, javaLangInvokeMethodHandleNatives, "linkCallSiteImpl",
"(Ljava/lang/Class;Ljava/lang/invoke/MethodHandle;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/Object;[Ljava/lang/Object;)Ljava/lang/invoke/MemberName;",
true);
}
return javaLangInvokeMethodHandleNativesLinkCallSiteImpl;
}

public JNIMethodId getJavaLangInvokeMethodHandleNativesLinkCallSite(JNIEnvironment env) {
if (javaLangInvokeMethodHandleNativesLinkCallSite.isNull()) {
JNIObjectHandle javaLangInvokeMethodHandleNatives = findClass(env, "java/lang/invoke/MethodHandleNatives");
javaLangInvokeMethodHandleNativesLinkCallSite = getMethodIdOptional(env, javaLangInvokeMethodHandleNatives, "linkCallSite",
"(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;[Ljava/lang/Object;)Ljava/lang/invoke/MemberName;", true);
}
return javaLangInvokeMethodHandleNativesLinkCallSite;
}
}
Loading

0 comments on commit 36af33b

Please sign in to comment.