From fd90dc236843db366a3d32f8607ebcce8bd15ee3 Mon Sep 17 00:00:00 2001 From: Andrea Peruffo Date: Mon, 29 Jul 2024 11:43:00 +0100 Subject: [PATCH] Fix issue 401 (#443) I ended up rewriting the entire control flow in a way that maps closely to the `TypeValidator` implementation. Fix #401 --- .../dylibso/chicory/runtime/CtrlFrame.java | 21 +++ .../chicory/runtime/InterpreterMachine.java | 142 ++++++++++-------- .../dylibso/chicory/runtime/StackFrame.java | 69 ++++++--- .../dylibso/chicory/runtime/ModuleTest.java | 21 +++ .../src/main/resources/compiled/fac.wat.wasm | Bin 0 -> 103 bytes wasm-corpus/src/main/resources/wat/fac.wat | 29 ++++ 6 files changed, 200 insertions(+), 82 deletions(-) create mode 100644 runtime/src/main/java/com/dylibso/chicory/runtime/CtrlFrame.java create mode 100644 wasm-corpus/src/main/resources/compiled/fac.wat.wasm create mode 100644 wasm-corpus/src/main/resources/wat/fac.wat diff --git a/runtime/src/main/java/com/dylibso/chicory/runtime/CtrlFrame.java b/runtime/src/main/java/com/dylibso/chicory/runtime/CtrlFrame.java new file mode 100644 index 000000000..1ce808bc4 --- /dev/null +++ b/runtime/src/main/java/com/dylibso/chicory/runtime/CtrlFrame.java @@ -0,0 +1,21 @@ +package com.dylibso.chicory.runtime; + +import com.dylibso.chicory.wasm.types.OpCode; + +public class CtrlFrame { + // OpCode of the current Control Flow instruction + public final OpCode opCode; + // params or inputs + public final int startValues; + // returns or outputs + public final int endValues; + // the height of the stack before entering the current Control Flow instruction + public final int height; + + public CtrlFrame(OpCode opCode, int startValues, int endValues, int height) { + this.opCode = opCode; + this.startValues = startValues; + this.endValues = endValues; + this.height = height; + } +} diff --git a/runtime/src/main/java/com/dylibso/chicory/runtime/InterpreterMachine.java b/runtime/src/main/java/com/dylibso/chicory/runtime/InterpreterMachine.java index bd4c5a225..f6b35b447 100644 --- a/runtime/src/main/java/com/dylibso/chicory/runtime/InterpreterMachine.java +++ b/runtime/src/main/java/com/dylibso/chicory/runtime/InterpreterMachine.java @@ -54,11 +54,17 @@ public static Value[] call( var func = instance.function(funcId); if (func != null) { - callStack.push( - new StackFrame(func.instructions(), instance, funcId, args, func.localTypes())); + var stackFrame = + new StackFrame(func.instructions(), instance, funcId, args, func.localTypes()); + stackFrame.pushCtrl(OpCode.CALL, 0, type.returns().size(), stack.size()); + callStack.push(stackFrame); + eval(stack, instance, callStack); } else { - callStack.push(new StackFrame(instance, funcId, args, List.of())); + var stackFrame = new StackFrame(instance, funcId, args, List.of()); + stackFrame.pushCtrl(OpCode.CALL, 0, type.returns().size(), stack.size()); + callStack.push(stackFrame); + var results = instance.callHostFunction(funcId, args); // a host function can return null or an array of ints // which we will push onto the stack @@ -119,27 +125,44 @@ static void eval(MStack stack, Instance instance, ArrayDeque callSta break; case LOOP: case BLOCK: - BLOCK(frame, stack); + BLOCK(frame, stack, instance, instruction); break; case IF: - IF(frame, stack, instruction); + IF(frame, stack, instance, instruction); break; - case BR: - checkInterruption(); - // fall through case ELSE: - prepareControlTransfer(frame, stack, false); frame.jumpTo(instruction.labelTrue()); break; + case BR: + BR(frame, stack, instruction); + break; case BR_IF: BR_IF(frame, stack, instruction); break; case BR_TABLE: BR_TABLE(frame, stack, instruction); break; + case END: + { + var ctrlFrame = frame.popCtrl(); + StackFrame.doControlTransfer(ctrlFrame, stack); + + // if this is the last end, then we're done with + // the function + if (frame.isLastBlock()) { + break loop; + } + break; + } case RETURN: - shouldReturn = true; - break; + { + // RETURN doesn't pass through the END + var ctrlFrame = frame.popCtrlTillCall(); + StackFrame.doControlTransfer(ctrlFrame, stack); + + shouldReturn = true; + break; + } case CALL_INDIRECT: CALL_INDIRECT(stack, instance, callStack, operands); break; @@ -152,21 +175,6 @@ static void eval(MStack stack, Instance instance, ArrayDeque callSta case SELECT_T: SELECT_T(stack, operands); break; - case END: - { - if (frame.doControlTransfer && frame.isControlFrame) { - doControlTransfer(instance, stack, frame, instruction.scope()); - } else { - frame.endOfNonControlBlock(); - } - - // if this is the last end, then we're done with - // the function - if (frame.isLastBlock()) { - break loop; - } - break; - } case LOCAL_GET: stack.push(frame.local((int) operands[0])); break; @@ -1844,9 +1852,15 @@ private static void CALL_INDIRECT( call(stack, instance, callStack, funcId, args, type, false); } - private static void BLOCK(StackFrame frame, MStack stack) { - frame.isControlFrame = true; - frame.registerStackSize(stack); + private static int numberOfParams(Instance instance, Instruction scope) { + var typeId = (int) scope.operands()[0]; + if (typeId == 0x40) { // epsilon + return 0; + } + if (ValueType.isValid(typeId)) { + return 0; + } + return instance.type(typeId).params().size(); } private static int numberOfValuesToReturn(Instance instance, Instruction scope) { @@ -1863,63 +1877,65 @@ private static int numberOfValuesToReturn(Instance instance, Instruction scope) return instance.type(typeId).returns().size(); } - private static void IF(StackFrame frame, MStack stack, Instruction instruction) { - frame.isControlFrame = false; - frame.registerStackSize(stack); + private static void BLOCK( + StackFrame frame, MStack stack, Instance instance, Instruction instruction) { + var paramsSize = numberOfParams(instance, instruction); + var returnsSize = numberOfValuesToReturn(instance, instruction); + frame.pushCtrl(instruction.opcode(), paramsSize, returnsSize, stack.size() - paramsSize); + } + + private static void IF( + StackFrame frame, MStack stack, Instance instance, Instruction instruction) { var predValue = stack.pop(); + var paramsSize = numberOfParams(instance, instruction); + var returnsSize = numberOfValuesToReturn(instance, instruction); + frame.pushCtrl(instruction.opcode(), paramsSize, returnsSize, stack.size() - paramsSize); + frame.jumpTo(predValue.asInt() == 0 ? instruction.labelFalse() : instruction.labelTrue()); } + private static void ctrlJump(StackFrame frame, MStack stack, int n) { + var ctrlFrame = frame.popCtrl(n); + frame.pushCtrl(ctrlFrame); + // a LOOP jumps back to the first instruction without passing through an END + if (ctrlFrame.opCode == OpCode.LOOP) { + StackFrame.doControlTransfer(ctrlFrame, stack); + } + } + + private static void BR(StackFrame frame, MStack stack, Instruction instruction) { + checkInterruption(); + ctrlJump(frame, stack, (int) instruction.operands()[0]); + frame.jumpTo(instruction.labelTrue()); + } + private static void BR_TABLE(StackFrame frame, MStack stack, Instruction instruction) { - var predValue = prepareControlTransfer(frame, stack, true); + var predValue = stack.pop(); var pred = predValue.asInt(); - if (pred < 0 || pred >= instruction.labelTable().length - 1) { + var defaultIdx = instruction.operands().length - 1; + if (pred < 0 || pred >= defaultIdx) { // choose default - frame.jumpTo(instruction.labelTable()[instruction.labelTable().length - 1]); + ctrlJump(frame, stack, (int) instruction.operands()[defaultIdx]); + frame.jumpTo(instruction.labelTable()[defaultIdx]); } else { + ctrlJump(frame, stack, (int) instruction.operands()[pred]); frame.jumpTo(instruction.labelTable()[pred]); } } private static void BR_IF(StackFrame frame, MStack stack, Instruction instruction) { - var predValue = prepareControlTransfer(frame, stack, true); + var predValue = stack.pop(); var pred = predValue.asInt(); if (pred == 0) { frame.jumpTo(instruction.labelFalse()); } else { + ctrlJump(frame, stack, (int) instruction.operands()[0]); frame.jumpTo(instruction.labelTrue()); } } - private static Value prepareControlTransfer(StackFrame frame, MStack stack, boolean consume) { - frame.doControlTransfer = true; - frame.isControlFrame = true; - return consume ? stack.pop() : null; - } - - private static void doControlTransfer( - Instance instance, MStack stack, StackFrame frame, Instruction scope) { - // reset the control transfer - frame.doControlTransfer = false; - - Value[] returns = new Value[numberOfValuesToReturn(instance, scope)]; - for (int i = 0; i < returns.length; i++) { - if (stack.size() > 0) returns[i] = stack.pop(); - } - - // drop everything till the previous label - frame.dropValuesOutOfBlock(stack); - - for (int i = 0; i < returns.length; i++) { - Value value = returns[returns.length - 1 - i]; - if (value != null) { - stack.push(value); - } - } - } - @Override public List getStackTrace() { return List.copyOf(callStack); diff --git a/runtime/src/main/java/com/dylibso/chicory/runtime/StackFrame.java b/runtime/src/main/java/com/dylibso/chicory/runtime/StackFrame.java index 721fc40ce..5d9a3c582 100644 --- a/runtime/src/main/java/com/dylibso/chicory/runtime/StackFrame.java +++ b/runtime/src/main/java/com/dylibso/chicory/runtime/StackFrame.java @@ -1,9 +1,10 @@ package com.dylibso.chicory.runtime; import com.dylibso.chicory.wasm.types.Instruction; +import com.dylibso.chicory.wasm.types.OpCode; import com.dylibso.chicory.wasm.types.Value; import com.dylibso.chicory.wasm.types.ValueType; -import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -19,9 +20,6 @@ * within the function you are in and only specific places. */ public class StackFrame { - public boolean doControlTransfer = false; - public boolean isControlFrame = true; - private final List code; private Instruction currentInstruction; @@ -30,7 +28,7 @@ public class StackFrame { private final Value[] locals; private final Instance instance; - private final ArrayDeque stackSizeBeforeBlock = new ArrayDeque<>(); + private final List ctrlStack = new ArrayList<>(); public StackFrame(Instance instance, int funcId, Value[] args, List localTypes) { this(Collections.emptyList(), instance, funcId, args, localTypes); @@ -89,30 +87,63 @@ public boolean terminated() { return pc >= code.size(); } - public void registerStackSize(MStack stack) { - stackSizeBeforeBlock.push(stack.size()); + public void pushCtrl(CtrlFrame ctrlFrame) { + ctrlStack.add(ctrlFrame); } - public void jumpTo(int newPc) { - pc = newPc; + public void pushCtrl(OpCode opcode, int startValues, int returnValues, int height) { + ctrlStack.add(new CtrlFrame(opcode, startValues, returnValues, height)); + } + + public CtrlFrame popCtrl() { + var ctrlFrame = ctrlStack.remove(ctrlStack.size() - 1); + return ctrlFrame; } - public void dropValuesOutOfBlock(MStack stack) { - if (currentInstruction.depth() > 0) { - while (stackSizeBeforeBlock.size() > currentInstruction.depth()) { - stackSizeBeforeBlock.pop(); + public CtrlFrame popCtrl(int n) { + int mostRecentCallHeight = ctrlStack.size(); + while (true) { + if (ctrlStack.get(--mostRecentCallHeight).opCode == OpCode.CALL) { + break; } + } + var finalHeight = ctrlStack.size() - (mostRecentCallHeight + n + 1); + CtrlFrame ctrlFrame = null; + while (ctrlStack.size() > finalHeight) { + ctrlFrame = popCtrl(); + } + return ctrlFrame; + } - int expectedStackSize = stackSizeBeforeBlock.pop(); - while (stack.size() > expectedStackSize) { - stack.pop(); + public CtrlFrame popCtrlTillCall() { + while (true) { + var ctrlFrame = popCtrl(); + if (ctrlFrame.opCode == OpCode.CALL) { + return ctrlFrame; } } } - public void endOfNonControlBlock() { - if (currentInstruction.depth() > 0) { - stackSizeBeforeBlock.pop(); + public void jumpTo(int newPc) { + pc = newPc; + } + + public static void doControlTransfer(CtrlFrame ctrlFrame, MStack stack) { + var endResults = ctrlFrame.startValues + ctrlFrame.endValues; // unwind stack + Value[] returns = new Value[endResults]; + for (int i = 0; i < returns.length; i++) { + if (stack.size() > 0) returns[i] = stack.pop(); + } + + while (stack.size() > ctrlFrame.height) { + stack.pop(); + } + + for (int i = 0; i < returns.length; i++) { + Value value = returns[returns.length - 1 - i]; + if (value != null) { + stack.push(value); + } } } } diff --git a/runtime/src/test/java/com/dylibso/chicory/runtime/ModuleTest.java b/runtime/src/test/java/com/dylibso/chicory/runtime/ModuleTest.java index c8bc0093c..b2e9bb336 100644 --- a/runtime/src/test/java/com/dylibso/chicory/runtime/ModuleTest.java +++ b/runtime/src/test/java/com/dylibso/chicory/runtime/ModuleTest.java @@ -355,4 +355,25 @@ public void shouldValidateTypes() { .build() .instantiate()); } + + @Test + public void shouldConsumeStackLoopOperations() { + AtomicLong finalStackSize = new AtomicLong(0); + var instance = + Module.builder("compiled/fac.wat.wasm") + .withUnsafeExecutionListener( + (Instruction instruction, long[] operands, MStack stack) -> { + finalStackSize.set(stack.size()); + }) + .build() + .instantiate(); + var facSsa = instance.export("fac-ssa"); + + var number = 100; + var result = facSsa.apply(Value.i32(number)); + assertEquals(factorial(number), result[0].asInt()); + + // IIUC: 3 values returning from last CALL + 1 result + assertTrue(finalStackSize.get() == 4L); + } } diff --git a/wasm-corpus/src/main/resources/compiled/fac.wat.wasm b/wasm-corpus/src/main/resources/compiled/fac.wat.wasm new file mode 100644 index 0000000000000000000000000000000000000000..a250bd7a7a0d07557a4f0ce20b2f070da216be80 GIT binary patch literal 103 zcmXBGF$#b%6a&ztzo=Nm&c(qC=plLt=N7@usk8s=evy#8pthR;n0%p9ZJcVcivCi? m3my8v{;KEJT7?^T>9AVD@IgBZvLO-~sgapaM8+{8%^|$$rw|4J literal 0 HcmV?d00001 diff --git a/wasm-corpus/src/main/resources/wat/fac.wat b/wasm-corpus/src/main/resources/wat/fac.wat new file mode 100644 index 000000000..e52511899 --- /dev/null +++ b/wasm-corpus/src/main/resources/wat/fac.wat @@ -0,0 +1,29 @@ +(module + (type (;0;) (func (param i64) (result i64))) + (type (;1;) (func (param i64) (result i64 i64))) + (type (;2;) (func (param i64 i64) (result i64 i64 i64))) + (func (;0;) (type 1) (param i64) (result i64 i64) + local.get 0 + local.get 0) + (func (;1;) (type 2) (param i64 i64) (result i64 i64 i64) + local.get 0 + local.get 1 + local.get 0) + (func (;2;) (type 0) (param i64) (result i64) + i64.const 1 + local.get 0 + loop (param i64 i64) (result i64) ;; label = @1 + call 1 + call 1 + i64.mul + call 1 + i64.const 1 + i64.sub + call 0 + i64.const 0 + i64.gt_u + br_if 0 (;@1;) + drop + return + end) + (export "fac-ssa" (func 2)))