Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue 401 #443

Merged
merged 14 commits into from
Jul 29, 2024
21 changes: 21 additions & 0 deletions runtime/src/main/java/com/dylibso/chicory/runtime/CtrlFrame.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.dylibso.chicory.runtime;

import com.dylibso.chicory.wasm.types.OpCode;

public class CtrlFrame {
// OpCode of the current Control Flow instruction
public final OpCode opCode;
// params or inputs
public final int startValues;
// returns or outputs
public final int endValues;
// the height of the stack before entering the current Control Flow instruction
public final int height;

public CtrlFrame(OpCode opCode, int startValues, int endValues, int height) {
this.opCode = opCode;
this.startValues = startValues;
this.endValues = endValues;
this.height = height;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,17 @@ public static Value[] call(

var func = instance.function(funcId);
if (func != null) {
callStack.push(
new StackFrame(func.instructions(), instance, funcId, args, func.localTypes()));
var stackFrame =
new StackFrame(func.instructions(), instance, funcId, args, func.localTypes());
stackFrame.pushCtrl(OpCode.CALL, 0, type.returns().size(), stack.size());
callStack.push(stackFrame);

eval(stack, instance, callStack);
} else {
callStack.push(new StackFrame(instance, funcId, args, List.of()));
var stackFrame = new StackFrame(instance, funcId, args, List.of());
stackFrame.pushCtrl(OpCode.CALL, 0, type.returns().size(), stack.size());
callStack.push(stackFrame);

var results = instance.callHostFunction(funcId, args);
// a host function can return null or an array of ints
// which we will push onto the stack
Expand Down Expand Up @@ -119,27 +125,44 @@ static void eval(MStack stack, Instance instance, ArrayDeque<StackFrame> callSta
break;
case LOOP:
case BLOCK:
BLOCK(frame, stack);
BLOCK(frame, stack, instance, instruction);
break;
case IF:
IF(frame, stack, instruction);
IF(frame, stack, instance, instruction);
break;
case BR:
checkInterruption();
// fall through
case ELSE:
prepareControlTransfer(frame, stack, false);
frame.jumpTo(instruction.labelTrue());
break;
case BR:
BR(frame, stack, instruction);
break;
case BR_IF:
BR_IF(frame, stack, instruction);
break;
case BR_TABLE:
BR_TABLE(frame, stack, instruction);
break;
case END:
{
var ctrlFrame = frame.popCtrl();
StackFrame.doControlTransfer(ctrlFrame, stack);

// if this is the last end, then we're done with
// the function
if (frame.isLastBlock()) {
break loop;
}
break;
}
case RETURN:
shouldReturn = true;
break;
{
// RETURN doesn't pass through the END
var ctrlFrame = frame.popCtrlTillCall();
StackFrame.doControlTransfer(ctrlFrame, stack);

shouldReturn = true;
break;
}
case CALL_INDIRECT:
CALL_INDIRECT(stack, instance, callStack, operands);
break;
Expand All @@ -152,21 +175,6 @@ static void eval(MStack stack, Instance instance, ArrayDeque<StackFrame> callSta
case SELECT_T:
SELECT_T(stack, operands);
break;
case END:
{
if (frame.doControlTransfer && frame.isControlFrame) {
doControlTransfer(instance, stack, frame, instruction.scope());
} else {
frame.endOfNonControlBlock();
}

// if this is the last end, then we're done with
// the function
if (frame.isLastBlock()) {
break loop;
}
break;
}
case LOCAL_GET:
stack.push(frame.local((int) operands[0]));
break;
Expand Down Expand Up @@ -1844,9 +1852,15 @@ private static void CALL_INDIRECT(
call(stack, instance, callStack, funcId, args, type, false);
}

private static void BLOCK(StackFrame frame, MStack stack) {
frame.isControlFrame = true;
frame.registerStackSize(stack);
private static int numberOfParams(Instance instance, Instruction scope) {
var typeId = (int) scope.operands()[0];
if (typeId == 0x40) { // epsilon
return 0;
}
if (ValueType.isValid(typeId)) {
return 0;
}
return instance.type(typeId).params().size();
}

private static int numberOfValuesToReturn(Instance instance, Instruction scope) {
Expand All @@ -1863,63 +1877,65 @@ private static int numberOfValuesToReturn(Instance instance, Instruction scope)
return instance.type(typeId).returns().size();
}

private static void IF(StackFrame frame, MStack stack, Instruction instruction) {
frame.isControlFrame = false;
frame.registerStackSize(stack);
private static void BLOCK(
StackFrame frame, MStack stack, Instance instance, Instruction instruction) {
var paramsSize = numberOfParams(instance, instruction);
var returnsSize = numberOfValuesToReturn(instance, instruction);
frame.pushCtrl(instruction.opcode(), paramsSize, returnsSize, stack.size() - paramsSize);
}

private static void IF(
StackFrame frame, MStack stack, Instance instance, Instruction instruction) {
var predValue = stack.pop();
var paramsSize = numberOfParams(instance, instruction);
var returnsSize = numberOfValuesToReturn(instance, instruction);
frame.pushCtrl(instruction.opcode(), paramsSize, returnsSize, stack.size() - paramsSize);

frame.jumpTo(predValue.asInt() == 0 ? instruction.labelFalse() : instruction.labelTrue());
}

private static void ctrlJump(StackFrame frame, MStack stack, int n) {
var ctrlFrame = frame.popCtrl(n);
frame.pushCtrl(ctrlFrame);
// a LOOP jumps back to the first instruction without passing through an END
if (ctrlFrame.opCode == OpCode.LOOP) {
StackFrame.doControlTransfer(ctrlFrame, stack);
}
}

private static void BR(StackFrame frame, MStack stack, Instruction instruction) {
checkInterruption();
ctrlJump(frame, stack, (int) instruction.operands()[0]);
frame.jumpTo(instruction.labelTrue());
}

private static void BR_TABLE(StackFrame frame, MStack stack, Instruction instruction) {
var predValue = prepareControlTransfer(frame, stack, true);
var predValue = stack.pop();
var pred = predValue.asInt();

if (pred < 0 || pred >= instruction.labelTable().length - 1) {
var defaultIdx = instruction.operands().length - 1;
if (pred < 0 || pred >= defaultIdx) {
// choose default
frame.jumpTo(instruction.labelTable()[instruction.labelTable().length - 1]);
ctrlJump(frame, stack, (int) instruction.operands()[defaultIdx]);
frame.jumpTo(instruction.labelTable()[defaultIdx]);
} else {
ctrlJump(frame, stack, (int) instruction.operands()[pred]);
frame.jumpTo(instruction.labelTable()[pred]);
}
}

private static void BR_IF(StackFrame frame, MStack stack, Instruction instruction) {
var predValue = prepareControlTransfer(frame, stack, true);
var predValue = stack.pop();
var pred = predValue.asInt();

if (pred == 0) {
frame.jumpTo(instruction.labelFalse());
} else {
ctrlJump(frame, stack, (int) instruction.operands()[0]);
frame.jumpTo(instruction.labelTrue());
}
}

private static Value prepareControlTransfer(StackFrame frame, MStack stack, boolean consume) {
frame.doControlTransfer = true;
frame.isControlFrame = true;
return consume ? stack.pop() : null;
}

private static void doControlTransfer(
Instance instance, MStack stack, StackFrame frame, Instruction scope) {
// reset the control transfer
frame.doControlTransfer = false;

Value[] returns = new Value[numberOfValuesToReturn(instance, scope)];
for (int i = 0; i < returns.length; i++) {
if (stack.size() > 0) returns[i] = stack.pop();
}

// drop everything till the previous label
frame.dropValuesOutOfBlock(stack);

for (int i = 0; i < returns.length; i++) {
Value value = returns[returns.length - 1 - i];
if (value != null) {
stack.push(value);
}
}
}

@Override
public List<StackFrame> getStackTrace() {
return List.copyOf(callStack);
Expand Down
69 changes: 50 additions & 19 deletions runtime/src/main/java/com/dylibso/chicory/runtime/StackFrame.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package com.dylibso.chicory.runtime;

import com.dylibso.chicory.wasm.types.Instruction;
import com.dylibso.chicory.wasm.types.OpCode;
import com.dylibso.chicory.wasm.types.Value;
import com.dylibso.chicory.wasm.types.ValueType;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand All @@ -19,9 +20,6 @@
* within the function you are in and only specific places.
*/
public class StackFrame {
public boolean doControlTransfer = false;
public boolean isControlFrame = true;

private final List<Instruction> code;
private Instruction currentInstruction;

Expand All @@ -30,7 +28,7 @@ public class StackFrame {
private final Value[] locals;
private final Instance instance;

private final ArrayDeque<Integer> stackSizeBeforeBlock = new ArrayDeque<>();
private final List<CtrlFrame> ctrlStack = new ArrayList<>();

public StackFrame(Instance instance, int funcId, Value[] args, List<ValueType> localTypes) {
this(Collections.emptyList(), instance, funcId, args, localTypes);
Expand Down Expand Up @@ -89,30 +87,63 @@ public boolean terminated() {
return pc >= code.size();
}

public void registerStackSize(MStack stack) {
stackSizeBeforeBlock.push(stack.size());
public void pushCtrl(CtrlFrame ctrlFrame) {
ctrlStack.add(ctrlFrame);
}

public void jumpTo(int newPc) {
pc = newPc;
public void pushCtrl(OpCode opcode, int startValues, int returnValues, int height) {
ctrlStack.add(new CtrlFrame(opcode, startValues, returnValues, height));
}

public CtrlFrame popCtrl() {
var ctrlFrame = ctrlStack.remove(ctrlStack.size() - 1);
return ctrlFrame;
}

public void dropValuesOutOfBlock(MStack stack) {
if (currentInstruction.depth() > 0) {
while (stackSizeBeforeBlock.size() > currentInstruction.depth()) {
stackSizeBeforeBlock.pop();
public CtrlFrame popCtrl(int n) {
int mostRecentCallHeight = ctrlStack.size();
while (true) {
if (ctrlStack.get(--mostRecentCallHeight).opCode == OpCode.CALL) {
break;
}
}
var finalHeight = ctrlStack.size() - (mostRecentCallHeight + n + 1);
CtrlFrame ctrlFrame = null;
while (ctrlStack.size() > finalHeight) {
ctrlFrame = popCtrl();
}
return ctrlFrame;
}

int expectedStackSize = stackSizeBeforeBlock.pop();
while (stack.size() > expectedStackSize) {
stack.pop();
public CtrlFrame popCtrlTillCall() {
while (true) {
var ctrlFrame = popCtrl();
if (ctrlFrame.opCode == OpCode.CALL) {
return ctrlFrame;
}
}
}

public void endOfNonControlBlock() {
if (currentInstruction.depth() > 0) {
stackSizeBeforeBlock.pop();
public void jumpTo(int newPc) {
pc = newPc;
}

public static void doControlTransfer(CtrlFrame ctrlFrame, MStack stack) {
var endResults = ctrlFrame.startValues + ctrlFrame.endValues; // fix 401
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// unwind stack

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

Value[] returns = new Value[endResults];
for (int i = 0; i < returns.length; i++) {
if (stack.size() > 0) returns[i] = stack.pop();
}

while (stack.size() > ctrlFrame.height) {
stack.pop();
}

for (int i = 0; i < returns.length; i++) {
Value value = returns[returns.length - 1 - i];
if (value != null) {
stack.push(value);
}
}
}
}
21 changes: 21 additions & 0 deletions runtime/src/test/java/com/dylibso/chicory/runtime/ModuleTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -355,4 +355,25 @@ public void shouldValidateTypes() {
.build()
.instantiate());
}

@Test
public void shouldConsumeStackLoopOperations() {
AtomicLong finalStackSize = new AtomicLong(0);
var instance =
Module.builder("compiled/fac.wat.wasm")
.withUnsafeExecutionListener(
(Instruction instruction, long[] operands, MStack stack) -> {
finalStackSize.set(stack.size());
})
.build()
.instantiate();
var facSsa = instance.export("fac-ssa");

var number = 100;
var result = facSsa.apply(Value.i32(number));
assertEquals(factorial(number), result[0].asInt());

// IIUC: 3 values returning from last CALL + 1 result
assertTrue(finalStackSize.get() == 4L);
}
}
Binary file added wasm-corpus/src/main/resources/compiled/fac.wat.wasm
Binary file not shown.
Loading
Loading