Skip to content

Commit

Permalink
Optimize AOT CALL_INDIRECT for calls in the same module
Browse files Browse the repository at this point in the history
  • Loading branch information
electrum authored and andreaTP committed Oct 3, 2024
1 parent 4c21d93 commit 6c40a05
Show file tree
Hide file tree
Showing 15 changed files with 678 additions and 102 deletions.
3 changes: 2 additions & 1 deletion aot/src/main/java/com/dylibso/chicory/aot/AotEmitters.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,9 @@ public static void CALL_INDIRECT(AotContext ctx, AnnotatedInstruction ins, Metho
FunctionType functionType = ctx.types()[typeId];

asm.visitLdcInsn(tableIdx);
asm.visitVarInsn(Opcodes.ALOAD, ctx.memorySlot());
asm.visitVarInsn(Opcodes.ALOAD, ctx.instanceSlot());
// stack: arguments, funcTableIdx, tableIdx, instance
// stack: arguments, funcTableIdx, tableIdx, memory, instance

asm.visitMethodInsn(
Opcodes.INVOKESTATIC,
Expand Down
98 changes: 91 additions & 7 deletions aot/src/main/java/com/dylibso/chicory/aot/AotMachine.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

import static com.dylibso.chicory.aot.AotMethods.CHECK_INTERRUPTION;
import static com.dylibso.chicory.aot.AotMethods.INSTANCE_CALL_HOST_FUNCTION;
import static com.dylibso.chicory.aot.AotMethods.INSTANCE_TABLE;
import static com.dylibso.chicory.aot.AotMethods.TABLE_INSTANCE;
import static com.dylibso.chicory.aot.AotMethods.TABLE_REF;
import static com.dylibso.chicory.aot.AotMethods.THROW_INDIRECT_CALL_TYPE_MISMATCH;
import static com.dylibso.chicory.aot.AotMethods.THROW_TRAP_EXCEPTION;
import static com.dylibso.chicory.aot.AotUtil.callIndirectMethodName;
import static com.dylibso.chicory.aot.AotUtil.callIndirectMethodType;
import static com.dylibso.chicory.aot.AotUtil.defaultValue;
import static com.dylibso.chicory.aot.AotUtil.emitInvokeFunction;
import static com.dylibso.chicory.aot.AotUtil.emitInvokeStatic;
import static com.dylibso.chicory.aot.AotUtil.emitInvokeVirtual;
import static com.dylibso.chicory.aot.AotUtil.emitJvmToLong;
Expand Down Expand Up @@ -57,6 +62,7 @@
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -444,7 +450,7 @@ private byte[] compileClass(String className, FunctionSection functions) {
classWriter,
callIndirectMethodName(typeId),
callIndirectMethodType(type),
asm -> compileCallIndirect(asm, typeId, type));
asm -> compileCallIndirect(internalClassName, typeId, type, asm));
}

var returnTypes =
Expand Down Expand Up @@ -560,15 +566,93 @@ private static void emitConstructor(ClassVisitor writer) {
cons.visitEnd();
}

private static void compileCallIndirect(MethodVisitor asm, int typeId, FunctionType type) {
int slot = type.params().stream().mapToInt(AotUtil::slotCount).sum();
private void compileCallIndirect(
String internalClassName, int typeId, FunctionType type, MethodVisitor asm) {
int slots = type.params().stream().mapToInt(AotUtil::slotCount).sum();
int funcTableIdx = slots;
int tableIdx = slots + 1;
int memory = slots + 2;
int instance = slots + 3;
int table = slots + 4;
int funcId = slots + 5;
int refInstance = slots + 6;

emitInvokeStatic(asm, CHECK_INTERRUPTION);

// TableInstance table = instance.table(tableIdx);
asm.visitVarInsn(Opcodes.ALOAD, instance);
asm.visitVarInsn(Opcodes.ILOAD, tableIdx);
emitInvokeVirtual(asm, INSTANCE_TABLE);
asm.visitVarInsn(Opcodes.ASTORE, table);

// int funcId = tableRef(table, funcTableIdx);
asm.visitVarInsn(Opcodes.ALOAD, table);
asm.visitVarInsn(Opcodes.ILOAD, funcTableIdx);
emitInvokeStatic(asm, TABLE_REF);
asm.visitVarInsn(Opcodes.ISTORE, funcId);

// Instance refInstance = table.instance(funcTableIdx);
asm.visitVarInsn(Opcodes.ALOAD, table);
asm.visitVarInsn(Opcodes.ILOAD, funcTableIdx);
emitInvokeVirtual(asm, TABLE_INSTANCE);
asm.visitVarInsn(Opcodes.ASTORE, refInstance);

Label local = new Label();
Label other = new Label();

// if (refInstance == null || refInstance == instance)
asm.visitVarInsn(Opcodes.ALOAD, refInstance);
asm.visitJumpInsn(Opcodes.IFNULL, local);
asm.visitVarInsn(Opcodes.ALOAD, refInstance);
asm.visitVarInsn(Opcodes.ALOAD, instance);
asm.visitJumpInsn(Opcodes.IF_ACMPNE, other);

// local: call function in this module
asm.visitLabel(local);

int slot = 0;
for (ValueType param : type.params()) {
asm.visitVarInsn(loadTypeOpcode(param), slot);
slot += slotCount(param);
}
asm.visitVarInsn(Opcodes.ALOAD, memory);
asm.visitVarInsn(Opcodes.ALOAD, instance);

List<Integer> validIds = new ArrayList<>();
for (int i = 0; i < functionTypes.size(); i++) {
if (type.equals(functionTypes.get(i))) {
validIds.add(i);
}
}

Label invalid = new Label();
int[] keys = validIds.stream().mapToInt(x -> x).toArray();
Label[] labels = validIds.stream().map(x -> new Label()).toArray(Label[]::new);

asm.visitVarInsn(Opcodes.ILOAD, funcId);
asm.visitLookupSwitchInsn(invalid, keys, labels);

Label done = new Label();
for (int i = 0; i < validIds.size(); i++) {
asm.visitLabel(labels[i]);
emitInvokeFunction(asm, internalClassName, keys[i], type);
asm.visitJumpInsn(Opcodes.GOTO, done);
}

asm.visitLabel(invalid);
emitInvokeStatic(asm, THROW_INDIRECT_CALL_TYPE_MISMATCH);
asm.visitInsn(Opcodes.ATHROW);

asm.visitLabel(done);
asm.visitInsn(returnTypeOpcode(type));

// other: call function in another module
asm.visitLabel(other);

// parameters: arguments, funcTableIdx, tableIdx, instance
emitBoxArguments(asm, type.params());
asm.visitLdcInsn(typeId);
asm.visitVarInsn(Opcodes.ILOAD, slot); // funcTableIdx
asm.visitVarInsn(Opcodes.ILOAD, slot + 1); // tableIdx
asm.visitVarInsn(Opcodes.ALOAD, slot + 2); // instance
asm.visitVarInsn(Opcodes.ILOAD, funcId);
asm.visitVarInsn(Opcodes.ALOAD, refInstance);

emitInvokeStatic(asm, AotMethods.CALL_INDIRECT);

Expand Down
47 changes: 26 additions & 21 deletions aot/src/main/java/com/dylibso/chicory/aot/AotMethods.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.dylibso.chicory.aot;

import static com.dylibso.chicory.wasm.types.Value.REF_NULL_VALUE;
import static java.util.Objects.requireNonNullElse;

import com.dylibso.chicory.runtime.Instance;
import com.dylibso.chicory.runtime.Memory;
Expand All @@ -23,6 +22,7 @@ public final class AotMethods {
static final Method INSTANCE_READ_GLOBAL;
static final Method WRITE_GLOBAL;
static final Method INSTANCE_SET_ELEMENT;
static final Method INSTANCE_TABLE;
static final Method MEMORY_COPY;
static final Method MEMORY_FILL;
static final Method MEMORY_INIT;
Expand All @@ -49,7 +49,10 @@ public final class AotMethods {
static final Method TABLE_FILL;
static final Method TABLE_COPY;
static final Method TABLE_INIT;
static final Method TABLE_REF;
static final Method TABLE_INSTANCE;
static final Method VALIDATE_BASE;
static final Method THROW_INDIRECT_CALL_TYPE_MISMATCH;
static final Method THROW_OUT_OF_BOUNDS_MEMORY_ACCESS;
static final Method THROW_TRAP_EXCEPTION;

Expand All @@ -58,19 +61,15 @@ public final class AotMethods {
CHECK_INTERRUPTION = AotMethods.class.getMethod("checkInterruption");
CALL_INDIRECT =
AotMethods.class.getMethod(
"callIndirect",
long[].class,
int.class,
int.class,
int.class,
Instance.class);
"callIndirect", long[].class, int.class, int.class, Instance.class);
INSTANCE_CALL_HOST_FUNCTION =
Instance.class.getMethod("callHostFunction", int.class, long[].class);
INSTANCE_READ_GLOBAL = Instance.class.getMethod("readGlobal", int.class);
WRITE_GLOBAL =
AotMethods.class.getMethod(
"writeGlobal", long.class, int.class, Instance.class);
INSTANCE_SET_ELEMENT = Instance.class.getMethod("setElement", int.class, Element.class);
INSTANCE_TABLE = Instance.class.getMethod("table", int.class);
MEMORY_COPY =
AotMethods.class.getMethod(
"memoryCopy", int.class, int.class, int.class, Memory.class);
Expand Down Expand Up @@ -154,7 +153,11 @@ public final class AotMethods {
int.class,
int.class,
Instance.class);
TABLE_REF = AotMethods.class.getMethod("tableRef", TableInstance.class, int.class);
TABLE_INSTANCE = TableInstance.class.getMethod("instance", int.class);
VALIDATE_BASE = AotMethods.class.getMethod("validateBase", int.class);
THROW_INDIRECT_CALL_TYPE_MISMATCH =
AotMethods.class.getMethod("throwIndirectCallTypeMismatch");
THROW_OUT_OF_BOUNDS_MEMORY_ACCESS =
AotMethods.class.getMethod("throwOutOfBoundsMemoryAccess");
THROW_TRAP_EXCEPTION = AotMethods.class.getMethod("throwTrapException");
Expand All @@ -166,24 +169,12 @@ public final class AotMethods {
private AotMethods() {}

@UsedByGeneratedCode
public static long[] callIndirect(
long[] args, int typeId, int funcTableIdx, int tableIdx, Instance instance) {
TableInstance table = instance.table(tableIdx);

instance = requireNonNullElse(table.instance(funcTableIdx), instance);

int funcId = table.ref(funcTableIdx);
if (funcId == REF_NULL_VALUE) {
throw new ChicoryException("uninitialized element " + funcTableIdx);
}

public static long[] callIndirect(long[] args, int typeId, int funcId, Instance instance) {
FunctionType expectedType = instance.type(typeId);
FunctionType actualType = instance.type(instance.functionType(funcId));
if (!actualType.typesMatch(expectedType)) {
throw new ChicoryException("indirect call type mismatch");
throw throwIndirectCallTypeMismatch();
}

checkInterruption();
return instance.getMachine().call(funcId, args);
}

Expand Down Expand Up @@ -230,6 +221,15 @@ public static void tableInit(
OpcodeImpl.TABLE_INIT(instance, tableidx, elementidx, size, elemidx, offset);
}

@UsedByGeneratedCode
public static int tableRef(TableInstance table, int index) {
int funcId = table.ref(index);
if (funcId == REF_NULL_VALUE) {
throw new ChicoryException("uninitialized element " + index);
}
return funcId;
}

@UsedByGeneratedCode
public static void memoryCopy(int destination, int offset, int size, Memory memory) {
memory.copy(destination, offset, size);
Expand Down Expand Up @@ -326,6 +326,11 @@ public static void validateBase(int base) {
}
}

@UsedByGeneratedCode
public static RuntimeException throwIndirectCallTypeMismatch() {
return new ChicoryException("indirect call type mismatch");
}

@UsedByGeneratedCode
public static RuntimeException throwOutOfBoundsMemoryAccess() {
throw new WASMRuntimeException("out of bounds memory access");
Expand Down
2 changes: 1 addition & 1 deletion aot/src/main/java/com/dylibso/chicory/aot/AotUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public static MethodHandle jvmToLongHandle(ValueType type) {

public static MethodType callIndirectMethodType(FunctionType functionType) {
return rawMethodTypeFor(functionType)
.appendParameterTypes(int.class, int.class, Instance.class);
.appendParameterTypes(int.class, int.class, Memory.class, Instance.class);
}

public static MethodType methodTypeFor(FunctionType type) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,42 @@ public final class com/dylibso/chicory/$gen/CompiledModule {
ATHROW

// access flags 0x9
public static call_indirect_0(IIILcom/dylibso/chicory/runtime/Instance;)I
public static call_indirect_0(IIILcom/dylibso/chicory/runtime/Memory;Lcom/dylibso/chicory/runtime/Instance;)I
INVOKESTATIC com/dylibso/chicory/aot/AotMethods.checkInterruption ()V
ALOAD 4
ILOAD 2
INVOKEVIRTUAL com/dylibso/chicory/runtime/Instance.table (I)Lcom/dylibso/chicory/runtime/TableInstance;
ASTORE 5
ALOAD 5
ILOAD 1
INVOKESTATIC com/dylibso/chicory/aot/AotMethods.tableRef (Lcom/dylibso/chicory/runtime/TableInstance;I)I
ISTORE 6
ALOAD 5
ILOAD 1
INVOKEVIRTUAL com/dylibso/chicory/runtime/TableInstance.instance (I)Lcom/dylibso/chicory/runtime/Instance;
ASTORE 7
ALOAD 7
IFNULL L0
ALOAD 7
ALOAD 4
IF_ACMPNE L1
L0
ILOAD 0
ALOAD 3
ALOAD 4
ILOAD 6
LOOKUPSWITCH
0: L2
default: L3
L2
INVOKESTATIC com/dylibso/chicory/$gen/CompiledModule.func_0 (ILcom/dylibso/chicory/runtime/Memory;Lcom/dylibso/chicory/runtime/Instance;)I
GOTO L4
L3
INVOKESTATIC com/dylibso/chicory/aot/AotMethods.throwIndirectCallTypeMismatch ()Ljava/lang/RuntimeException;
ATHROW
L4
IRETURN
L1
ICONST_1
NEWARRAY T_LONG
DUP
Expand All @@ -53,10 +88,9 @@ public final class com/dylibso/chicory/$gen/CompiledModule {
I2L
LASTORE
ICONST_0
ILOAD 1
ILOAD 2
ALOAD 3
INVOKESTATIC com/dylibso/chicory/aot/AotMethods.callIndirect ([JIIILcom/dylibso/chicory/runtime/Instance;)[J
ILOAD 6
ALOAD 7
INVOKESTATIC com/dylibso/chicory/aot/AotMethods.callIndirect ([JIILcom/dylibso/chicory/runtime/Instance;)[J
ICONST_0
LALOAD
L2I
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,42 @@ public final class com/dylibso/chicory/$gen/CompiledModule {
IRETURN

// access flags 0x9
public static call_indirect_0(IIILcom/dylibso/chicory/runtime/Instance;)I
public static call_indirect_0(IIILcom/dylibso/chicory/runtime/Memory;Lcom/dylibso/chicory/runtime/Instance;)I
INVOKESTATIC com/dylibso/chicory/aot/AotMethods.checkInterruption ()V
ALOAD 4
ILOAD 2
INVOKEVIRTUAL com/dylibso/chicory/runtime/Instance.table (I)Lcom/dylibso/chicory/runtime/TableInstance;
ASTORE 5
ALOAD 5
ILOAD 1
INVOKESTATIC com/dylibso/chicory/aot/AotMethods.tableRef (Lcom/dylibso/chicory/runtime/TableInstance;I)I
ISTORE 6
ALOAD 5
ILOAD 1
INVOKEVIRTUAL com/dylibso/chicory/runtime/TableInstance.instance (I)Lcom/dylibso/chicory/runtime/Instance;
ASTORE 7
ALOAD 7
IFNULL L0
ALOAD 7
ALOAD 4
IF_ACMPNE L1
L0
ILOAD 0
ALOAD 3
ALOAD 4
ILOAD 6
LOOKUPSWITCH
0: L2
default: L3
L2
INVOKESTATIC com/dylibso/chicory/$gen/CompiledModule.func_0 (ILcom/dylibso/chicory/runtime/Memory;Lcom/dylibso/chicory/runtime/Instance;)I
GOTO L4
L3
INVOKESTATIC com/dylibso/chicory/aot/AotMethods.throwIndirectCallTypeMismatch ()Ljava/lang/RuntimeException;
ATHROW
L4
IRETURN
L1
ICONST_1
NEWARRAY T_LONG
DUP
Expand All @@ -50,10 +85,9 @@ public final class com/dylibso/chicory/$gen/CompiledModule {
I2L
LASTORE
ICONST_0
ILOAD 1
ILOAD 2
ALOAD 3
INVOKESTATIC com/dylibso/chicory/aot/AotMethods.callIndirect ([JIIILcom/dylibso/chicory/runtime/Instance;)[J
ILOAD 6
ALOAD 7
INVOKESTATIC com/dylibso/chicory/aot/AotMethods.callIndirect ([JIILcom/dylibso/chicory/runtime/Instance;)[J
ICONST_0
LALOAD
L2I
Expand Down
Loading

0 comments on commit 6c40a05

Please sign in to comment.