Skip to content

Commit

Permalink
Fix issue 401 (#443)
Browse files Browse the repository at this point in the history
I ended up rewriting the entire control flow in a way that maps closely
to the `TypeValidator` implementation.

Fix #401
  • Loading branch information
andreaTP authored Jul 29, 2024
1 parent 26b8efd commit fd90dc2
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 82 deletions.
21 changes: 21 additions & 0 deletions runtime/src/main/java/com/dylibso/chicory/runtime/CtrlFrame.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.dylibso.chicory.runtime;

import com.dylibso.chicory.wasm.types.OpCode;

public class CtrlFrame {
// OpCode of the current Control Flow instruction
public final OpCode opCode;
// params or inputs
public final int startValues;
// returns or outputs
public final int endValues;
// the height of the stack before entering the current Control Flow instruction
public final int height;

public CtrlFrame(OpCode opCode, int startValues, int endValues, int height) {
this.opCode = opCode;
this.startValues = startValues;
this.endValues = endValues;
this.height = height;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,17 @@ public static Value[] call(

var func = instance.function(funcId);
if (func != null) {
callStack.push(
new StackFrame(func.instructions(), instance, funcId, args, func.localTypes()));
var stackFrame =
new StackFrame(func.instructions(), instance, funcId, args, func.localTypes());
stackFrame.pushCtrl(OpCode.CALL, 0, type.returns().size(), stack.size());
callStack.push(stackFrame);

eval(stack, instance, callStack);
} else {
callStack.push(new StackFrame(instance, funcId, args, List.of()));
var stackFrame = new StackFrame(instance, funcId, args, List.of());
stackFrame.pushCtrl(OpCode.CALL, 0, type.returns().size(), stack.size());
callStack.push(stackFrame);

var results = instance.callHostFunction(funcId, args);
// a host function can return null or an array of ints
// which we will push onto the stack
Expand Down Expand Up @@ -119,27 +125,44 @@ static void eval(MStack stack, Instance instance, ArrayDeque<StackFrame> callSta
break;
case LOOP:
case BLOCK:
BLOCK(frame, stack);
BLOCK(frame, stack, instance, instruction);
break;
case IF:
IF(frame, stack, instruction);
IF(frame, stack, instance, instruction);
break;
case BR:
checkInterruption();
// fall through
case ELSE:
prepareControlTransfer(frame, stack, false);
frame.jumpTo(instruction.labelTrue());
break;
case BR:
BR(frame, stack, instruction);
break;
case BR_IF:
BR_IF(frame, stack, instruction);
break;
case BR_TABLE:
BR_TABLE(frame, stack, instruction);
break;
case END:
{
var ctrlFrame = frame.popCtrl();
StackFrame.doControlTransfer(ctrlFrame, stack);

// if this is the last end, then we're done with
// the function
if (frame.isLastBlock()) {
break loop;
}
break;
}
case RETURN:
shouldReturn = true;
break;
{
// RETURN doesn't pass through the END
var ctrlFrame = frame.popCtrlTillCall();
StackFrame.doControlTransfer(ctrlFrame, stack);

shouldReturn = true;
break;
}
case CALL_INDIRECT:
CALL_INDIRECT(stack, instance, callStack, operands);
break;
Expand All @@ -152,21 +175,6 @@ static void eval(MStack stack, Instance instance, ArrayDeque<StackFrame> callSta
case SELECT_T:
SELECT_T(stack, operands);
break;
case END:
{
if (frame.doControlTransfer && frame.isControlFrame) {
doControlTransfer(instance, stack, frame, instruction.scope());
} else {
frame.endOfNonControlBlock();
}

// if this is the last end, then we're done with
// the function
if (frame.isLastBlock()) {
break loop;
}
break;
}
case LOCAL_GET:
stack.push(frame.local((int) operands[0]));
break;
Expand Down Expand Up @@ -1844,9 +1852,15 @@ private static void CALL_INDIRECT(
call(stack, instance, callStack, funcId, args, type, false);
}

private static void BLOCK(StackFrame frame, MStack stack) {
frame.isControlFrame = true;
frame.registerStackSize(stack);
private static int numberOfParams(Instance instance, Instruction scope) {
var typeId = (int) scope.operands()[0];
if (typeId == 0x40) { // epsilon
return 0;
}
if (ValueType.isValid(typeId)) {
return 0;
}
return instance.type(typeId).params().size();
}

private static int numberOfValuesToReturn(Instance instance, Instruction scope) {
Expand All @@ -1863,63 +1877,65 @@ private static int numberOfValuesToReturn(Instance instance, Instruction scope)
return instance.type(typeId).returns().size();
}

private static void IF(StackFrame frame, MStack stack, Instruction instruction) {
frame.isControlFrame = false;
frame.registerStackSize(stack);
private static void BLOCK(
StackFrame frame, MStack stack, Instance instance, Instruction instruction) {
var paramsSize = numberOfParams(instance, instruction);
var returnsSize = numberOfValuesToReturn(instance, instruction);
frame.pushCtrl(instruction.opcode(), paramsSize, returnsSize, stack.size() - paramsSize);
}

private static void IF(
StackFrame frame, MStack stack, Instance instance, Instruction instruction) {
var predValue = stack.pop();
var paramsSize = numberOfParams(instance, instruction);
var returnsSize = numberOfValuesToReturn(instance, instruction);
frame.pushCtrl(instruction.opcode(), paramsSize, returnsSize, stack.size() - paramsSize);

frame.jumpTo(predValue.asInt() == 0 ? instruction.labelFalse() : instruction.labelTrue());
}

private static void ctrlJump(StackFrame frame, MStack stack, int n) {
var ctrlFrame = frame.popCtrl(n);
frame.pushCtrl(ctrlFrame);
// a LOOP jumps back to the first instruction without passing through an END
if (ctrlFrame.opCode == OpCode.LOOP) {
StackFrame.doControlTransfer(ctrlFrame, stack);
}
}

private static void BR(StackFrame frame, MStack stack, Instruction instruction) {
checkInterruption();
ctrlJump(frame, stack, (int) instruction.operands()[0]);
frame.jumpTo(instruction.labelTrue());
}

private static void BR_TABLE(StackFrame frame, MStack stack, Instruction instruction) {
var predValue = prepareControlTransfer(frame, stack, true);
var predValue = stack.pop();
var pred = predValue.asInt();

if (pred < 0 || pred >= instruction.labelTable().length - 1) {
var defaultIdx = instruction.operands().length - 1;
if (pred < 0 || pred >= defaultIdx) {
// choose default
frame.jumpTo(instruction.labelTable()[instruction.labelTable().length - 1]);
ctrlJump(frame, stack, (int) instruction.operands()[defaultIdx]);
frame.jumpTo(instruction.labelTable()[defaultIdx]);
} else {
ctrlJump(frame, stack, (int) instruction.operands()[pred]);
frame.jumpTo(instruction.labelTable()[pred]);
}
}

private static void BR_IF(StackFrame frame, MStack stack, Instruction instruction) {
var predValue = prepareControlTransfer(frame, stack, true);
var predValue = stack.pop();
var pred = predValue.asInt();

if (pred == 0) {
frame.jumpTo(instruction.labelFalse());
} else {
ctrlJump(frame, stack, (int) instruction.operands()[0]);
frame.jumpTo(instruction.labelTrue());
}
}

private static Value prepareControlTransfer(StackFrame frame, MStack stack, boolean consume) {
frame.doControlTransfer = true;
frame.isControlFrame = true;
return consume ? stack.pop() : null;
}

private static void doControlTransfer(
Instance instance, MStack stack, StackFrame frame, Instruction scope) {
// reset the control transfer
frame.doControlTransfer = false;

Value[] returns = new Value[numberOfValuesToReturn(instance, scope)];
for (int i = 0; i < returns.length; i++) {
if (stack.size() > 0) returns[i] = stack.pop();
}

// drop everything till the previous label
frame.dropValuesOutOfBlock(stack);

for (int i = 0; i < returns.length; i++) {
Value value = returns[returns.length - 1 - i];
if (value != null) {
stack.push(value);
}
}
}

@Override
public List<StackFrame> getStackTrace() {
return List.copyOf(callStack);
Expand Down
69 changes: 50 additions & 19 deletions runtime/src/main/java/com/dylibso/chicory/runtime/StackFrame.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package com.dylibso.chicory.runtime;

import com.dylibso.chicory.wasm.types.Instruction;
import com.dylibso.chicory.wasm.types.OpCode;
import com.dylibso.chicory.wasm.types.Value;
import com.dylibso.chicory.wasm.types.ValueType;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand All @@ -19,9 +20,6 @@
* within the function you are in and only specific places.
*/
public class StackFrame {
public boolean doControlTransfer = false;
public boolean isControlFrame = true;

private final List<Instruction> code;
private Instruction currentInstruction;

Expand All @@ -30,7 +28,7 @@ public class StackFrame {
private final Value[] locals;
private final Instance instance;

private final ArrayDeque<Integer> stackSizeBeforeBlock = new ArrayDeque<>();
private final List<CtrlFrame> ctrlStack = new ArrayList<>();

public StackFrame(Instance instance, int funcId, Value[] args, List<ValueType> localTypes) {
this(Collections.emptyList(), instance, funcId, args, localTypes);
Expand Down Expand Up @@ -89,30 +87,63 @@ public boolean terminated() {
return pc >= code.size();
}

public void registerStackSize(MStack stack) {
stackSizeBeforeBlock.push(stack.size());
public void pushCtrl(CtrlFrame ctrlFrame) {
ctrlStack.add(ctrlFrame);
}

public void jumpTo(int newPc) {
pc = newPc;
public void pushCtrl(OpCode opcode, int startValues, int returnValues, int height) {
ctrlStack.add(new CtrlFrame(opcode, startValues, returnValues, height));
}

public CtrlFrame popCtrl() {
var ctrlFrame = ctrlStack.remove(ctrlStack.size() - 1);
return ctrlFrame;
}

public void dropValuesOutOfBlock(MStack stack) {
if (currentInstruction.depth() > 0) {
while (stackSizeBeforeBlock.size() > currentInstruction.depth()) {
stackSizeBeforeBlock.pop();
public CtrlFrame popCtrl(int n) {
int mostRecentCallHeight = ctrlStack.size();
while (true) {
if (ctrlStack.get(--mostRecentCallHeight).opCode == OpCode.CALL) {
break;
}
}
var finalHeight = ctrlStack.size() - (mostRecentCallHeight + n + 1);
CtrlFrame ctrlFrame = null;
while (ctrlStack.size() > finalHeight) {
ctrlFrame = popCtrl();
}
return ctrlFrame;
}

int expectedStackSize = stackSizeBeforeBlock.pop();
while (stack.size() > expectedStackSize) {
stack.pop();
public CtrlFrame popCtrlTillCall() {
while (true) {
var ctrlFrame = popCtrl();
if (ctrlFrame.opCode == OpCode.CALL) {
return ctrlFrame;
}
}
}

public void endOfNonControlBlock() {
if (currentInstruction.depth() > 0) {
stackSizeBeforeBlock.pop();
public void jumpTo(int newPc) {
pc = newPc;
}

public static void doControlTransfer(CtrlFrame ctrlFrame, MStack stack) {
var endResults = ctrlFrame.startValues + ctrlFrame.endValues; // unwind stack
Value[] returns = new Value[endResults];
for (int i = 0; i < returns.length; i++) {
if (stack.size() > 0) returns[i] = stack.pop();
}

while (stack.size() > ctrlFrame.height) {
stack.pop();
}

for (int i = 0; i < returns.length; i++) {
Value value = returns[returns.length - 1 - i];
if (value != null) {
stack.push(value);
}
}
}
}
21 changes: 21 additions & 0 deletions runtime/src/test/java/com/dylibso/chicory/runtime/ModuleTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -355,4 +355,25 @@ public void shouldValidateTypes() {
.build()
.instantiate());
}

@Test
public void shouldConsumeStackLoopOperations() {
AtomicLong finalStackSize = new AtomicLong(0);
var instance =
Module.builder("compiled/fac.wat.wasm")
.withUnsafeExecutionListener(
(Instruction instruction, long[] operands, MStack stack) -> {
finalStackSize.set(stack.size());
})
.build()
.instantiate();
var facSsa = instance.export("fac-ssa");

var number = 100;
var result = facSsa.apply(Value.i32(number));
assertEquals(factorial(number), result[0].asInt());

// IIUC: 3 values returning from last CALL + 1 result
assertTrue(finalStackSize.get() == 4L);
}
}
Binary file added wasm-corpus/src/main/resources/compiled/fac.wat.wasm
Binary file not shown.
Loading

0 comments on commit fd90dc2

Please sign in to comment.