From cb8172b3482205c43e577c303e1b38ae2ee20843 Mon Sep 17 00:00:00 2001 From: Andrea Peruffo Date: Mon, 19 Aug 2024 15:33:29 +0200 Subject: [PATCH] WIP - immutable Instruction --- .../com/dylibso/chicory/aot/AotMachine.java | 69 +++++---- .../chicory/runtime/InterpreterMachine.java | 20 ++- .../com/dylibso/chicory/wasm/ControlTree.java | 13 +- .../java/com/dylibso/chicory/wasm/Parser.java | 83 ++++++---- .../com/dylibso/chicory/wasm/Validator.java | 4 +- .../chicory/wasm/types/Instruction.java | 142 ++++++++++++++---- 6 files changed, 227 insertions(+), 104 deletions(-) diff --git a/aot/src/main/java/com/dylibso/chicory/aot/AotMachine.java b/aot/src/main/java/com/dylibso/chicory/aot/AotMachine.java index a70db7c85..327fc3870 100644 --- a/aot/src/main/java/com/dylibso/chicory/aot/AotMachine.java +++ b/aot/src/main/java/com/dylibso/chicory/aot/AotMachine.java @@ -79,7 +79,12 @@ public final class AotMachine implements Machine { public static final String DEFAULT_CLASS_NAME = "com.dylibso.chicory.$gen.CompiledModule"; - private static final Instruction FUNCTION_SCOPE = new Instruction(-1, OpCode.NOP, new long[0]); + private static final Instruction FUNCTION_SCOPE = + Instruction.builder() + .withAddress(-1) + .withOpcode(OpCode.NOP) + .withOperands(new long[0]) + .build(); private final Module module; private final Instance instance; @@ -642,17 +647,15 @@ private void compileBody( // allocate labels for all label targets Map labels = new HashMap<>(); for (Instruction ins : body.instructions()) { - if (ins.labelTrue() != null) { - labels.put(ins.labelTrue(), new Label()); - } - if (ins.labelFalse() != null) { - labels.put(ins.labelFalse(), new Label()); - } - if (ins.labelTable() != null) { - for (int label : ins.labelTable()) { - labels.put(label, new Label()); - } - } + ins.labelTrue().ifPresent(l -> labels.put(l, new Label())); + ins.labelFalse().ifPresent(l -> labels.put(l, new Label())); + ins.labelTable() + .ifPresent( + l -> { + for (int label : l) { + labels.put(label, new Label()); + } + }); } // fake instruction to use for the function's implicit block @@ -688,10 +691,10 @@ private void compileBody( break; case BLOCK: case LOOP: - ctx.enterScope(ins.scope(), blockType(ins)); + ctx.enterScope(ins.scope().get(), blockType(ins)); break; case END: - ctx.exitScope(ins.scope()); + ctx.exitScope(ins.scope().get()); break; case UNREACHABLE: exitBlockDepth = ins.depth(); @@ -704,34 +707,35 @@ private void compileBody( break; case IF: ctx.popStackSize(); - ctx.enterScope(ins.scope(), blockType(ins)); - asm.visitJumpInsn(Opcodes.IFEQ, labels.get(ins.labelFalse())); + ctx.enterScope(ins.scope().get(), blockType(ins)); + asm.visitJumpInsn(Opcodes.IFEQ, labels.get(ins.labelFalse().get())); // use the same starting stack sizes for both sides of the branch - if (body.instructions().get(ins.labelFalse() - 1).opcode() == OpCode.ELSE) { + if (body.instructions().get(ins.labelFalse().get() - 1).opcode() + == OpCode.ELSE) { ctx.pushStackSizesStack(); } break; case ELSE: - asm.visitJumpInsn(Opcodes.GOTO, labels.get(ins.labelTrue())); + asm.visitJumpInsn(Opcodes.GOTO, labels.get(ins.labelTrue().get())); ctx.popStackSizesStack(); break; case BR: exitBlockDepth = ins.depth(); - if (ins.labelTrue() < idx) { + if (ins.labelTrue().get() < idx) { emitInvokeStatic(asm, CHECK_INTERRUPTION); } - emitUnwindStack(asm, type, body, ins, ins.labelTrue(), ctx); - asm.visitJumpInsn(Opcodes.GOTO, labels.get(ins.labelTrue())); + emitUnwindStack(asm, type, body, ins, ins.labelTrue().get(), ctx); + asm.visitJumpInsn(Opcodes.GOTO, labels.get(ins.labelTrue().get())); break; case BR_IF: ctx.popStackSize(); Label falseLabel = new Label(); asm.visitJumpInsn(Opcodes.IFEQ, falseLabel); - if (ins.labelTrue() < idx) { + if (ins.labelTrue().get() < idx) { emitInvokeStatic(asm, CHECK_INTERRUPTION); } - emitUnwindStack(asm, type, body, ins, ins.labelTrue(), ctx); - asm.visitJumpInsn(Opcodes.GOTO, labels.get(ins.labelTrue())); + emitUnwindStack(asm, type, body, ins, ins.labelTrue().get(), ctx); + asm.visitJumpInsn(Opcodes.GOTO, labels.get(ins.labelTrue().get())); asm.visitLabel(falseLabel); break; case BR_TABLE: @@ -739,20 +743,23 @@ private void compileBody( ctx.popStackSize(); emitInvokeStatic(asm, CHECK_INTERRUPTION); // skip table switch if it only has a default - if (ins.labelTable().length == 1) { + if (ins.labelTable().get().size() == 1) { asm.visitInsn(Opcodes.POP); - emitUnwindStack(asm, type, body, ins, ins.labelTable()[0], ctx); - asm.visitJumpInsn(Opcodes.GOTO, labels.get(ins.labelTable()[0])); + emitUnwindStack(asm, type, body, ins, ins.labelTable().get().get(0), ctx); + asm.visitJumpInsn(Opcodes.GOTO, labels.get(ins.labelTable().get().get(0))); break; } // collect unique target labels Map targets = new HashMap<>(); - Label[] table = new Label[ins.labelTable().length - 1]; + Label[] table = new Label[ins.labelTable().get().size() - 1]; for (int i = 0; i < table.length; i++) { - table[i] = targets.computeIfAbsent(ins.labelTable()[i], x -> new Label()); + table[i] = + targets.computeIfAbsent( + ins.labelTable().get().get(i), x -> new Label()); } // table switch using the last entry of the label table as the default - int defaultTarget = ins.labelTable()[ins.labelTable().length - 1]; + int defaultTarget = + ins.labelTable().get().get(ins.labelTable().get().size() - 1); Label defaultLabel = targets.computeIfAbsent(defaultTarget, x -> new Label()); asm.visitTableSwitchInsn(0, table.length - 1, defaultLabel, table); // generate separate unwinds for each target @@ -814,7 +821,7 @@ private void emitUnwindStack( target = body.instructions().get(label - 1); forward = false; } - var scope = target.scope(); + var scope = target.scope().get(); FunctionType blockType; if (scope.opcode() == OpCode.END) { diff --git a/runtime/src/main/java/com/dylibso/chicory/runtime/InterpreterMachine.java b/runtime/src/main/java/com/dylibso/chicory/runtime/InterpreterMachine.java index e9fd2a5d6..ba5f8cff6 100644 --- a/runtime/src/main/java/com/dylibso/chicory/runtime/InterpreterMachine.java +++ b/runtime/src/main/java/com/dylibso/chicory/runtime/InterpreterMachine.java @@ -133,7 +133,7 @@ static void eval(MStack stack, Instance instance, ArrayDeque callSta IF(frame, stack, instance, instruction); break; case ELSE: - frame.jumpTo(instruction.labelTrue()); + frame.jumpTo(instruction.labelTrue().get()); break; case BR: BR(frame, stack, instruction); @@ -1883,7 +1883,10 @@ private static void IF( var returnsSize = numberOfValuesToReturn(instance, instruction); frame.pushCtrl(instruction.opcode(), paramsSize, returnsSize, stack.size() - paramsSize); - frame.jumpTo(predValue.asInt() == 0 ? instruction.labelFalse() : instruction.labelTrue()); + frame.jumpTo( + predValue.asInt() == 0 + ? instruction.labelFalse().get() + : instruction.labelTrue().get()); } private static void ctrlJump(StackFrame frame, MStack stack, int n) { @@ -1898,7 +1901,7 @@ private static void ctrlJump(StackFrame frame, MStack stack, int n) { private static void BR(StackFrame frame, MStack stack, Instruction instruction) { checkInterruption(); ctrlJump(frame, stack, (int) instruction.operands()[0]); - frame.jumpTo(instruction.labelTrue()); + frame.jumpTo(instruction.labelTrue().get()); } private static void BR_TABLE(StackFrame frame, MStack stack, Instruction instruction) { @@ -1909,10 +1912,10 @@ private static void BR_TABLE(StackFrame frame, MStack stack, Instruction instruc if (pred < 0 || pred >= defaultIdx) { // choose default ctrlJump(frame, stack, (int) instruction.operands()[defaultIdx]); - frame.jumpTo(instruction.labelTable()[defaultIdx]); + frame.jumpTo(instruction.labelTable().get().get(defaultIdx)); } else { ctrlJump(frame, stack, (int) instruction.operands()[pred]); - frame.jumpTo(instruction.labelTable()[pred]); + frame.jumpTo(instruction.labelTable().get().get(pred)); } } @@ -1921,10 +1924,13 @@ private static void BR_IF(StackFrame frame, MStack stack, Instruction instructio var pred = predValue.asInt(); if (pred == 0) { - frame.jumpTo(instruction.labelFalse()); + frame.jumpTo(instruction.labelFalse().get()); } else { ctrlJump(frame, stack, (int) instruction.operands()[0]); - frame.jumpTo(instruction.labelTrue()); + if (instruction.labelTrue().isEmpty()) { + System.out.println("debug me"); + } + frame.jumpTo(instruction.labelTrue().get()); } } diff --git a/wasm/src/main/java/com/dylibso/chicory/wasm/ControlTree.java b/wasm/src/main/java/com/dylibso/chicory/wasm/ControlTree.java index f5ffa7951..780ab6a57 100644 --- a/wasm/src/main/java/com/dylibso/chicory/wasm/ControlTree.java +++ b/wasm/src/main/java/com/dylibso/chicory/wasm/ControlTree.java @@ -51,7 +51,7 @@ * */ final class ControlTree { - private final Instruction instruction; + private final Instruction.Builder instruction; private final int initialInstructionNumber; private int finalInstructionNumber = -1; // to be set when END is reached private final ControlTree parent; @@ -66,7 +66,8 @@ public ControlTree() { this.callbacks = new ArrayList<>(); } - private ControlTree(int initialInstructionNumber, Instruction instruction, ControlTree parent) { + private ControlTree( + int initialInstructionNumber, Instruction.Builder instruction, ControlTree parent) { this.instruction = instruction; this.initialInstructionNumber = initialInstructionNumber; this.parent = parent; @@ -74,7 +75,7 @@ private ControlTree(int initialInstructionNumber, Instruction instruction, Contr this.callbacks = new ArrayList<>(); } - public ControlTree spawn(int initialInstructionNumber, Instruction instruction) { + public ControlTree spawn(int initialInstructionNumber, Instruction.Builder instruction) { var node = new ControlTree(initialInstructionNumber, instruction, this); this.addNested(node); return node; @@ -84,7 +85,7 @@ public boolean isRoot() { return this.parent == null; } - public Instruction instruction() { + public Instruction.Builder instruction() { return instruction; } @@ -108,10 +109,10 @@ public void addCallback(Consumer callback) { this.callbacks.add(callback); } - public void setFinalInstructionNumber(int finalInstructionNumber, Instruction end) { + public void setFinalInstructionNumber(int finalInstructionNumber, Instruction.Builder end) { this.finalInstructionNumber = finalInstructionNumber; - if (end.scope().opcode() == OpCode.LOOP) { + if (end.scope().isPresent() && end.scope().get().opcode() == OpCode.LOOP) { var lastLoopInstruction = 0; for (var ct : this.parent.nested) { if (ct.instruction().opcode() == OpCode.LOOP) { diff --git a/wasm/src/main/java/com/dylibso/chicory/wasm/Parser.java b/wasm/src/main/java/com/dylibso/chicory/wasm/Parser.java index 6b1c20695..46983d98b 100644 --- a/wasm/src/main/java/com/dylibso/chicory/wasm/Parser.java +++ b/wasm/src/main/java/com/dylibso/chicory/wasm/Parser.java @@ -67,6 +67,7 @@ import java.util.Map; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Collectors; /** * Parser for Web Assembly binaries. @@ -682,9 +683,16 @@ private static Element parseSingleElement(ByteBuffer buffer) { for (int i = 0; i < initCnt; i++) { inits.add( List.of( - new Instruction( - -1, OpCode.REF_FUNC, new long[] {readVarUInt32(buffer)}), - new Instruction(-1, OpCode.END, new long[0]))); + Instruction.builder() + .withAddress(-1) + .withOpcode(OpCode.REF_FUNC) + .withOperands(new long[] {readVarUInt32(buffer)}) + .build(), + Instruction.builder() + .withAddress(-1) + .withOpcode(OpCode.END) + .withOperands(new long[0]) + .build())); } } if (declarative) { @@ -727,7 +735,7 @@ private static CodeSection parseCodeSection(ByteBuffer buffer) { var depth = 0; var funcEndPoint = readVarUInt32(buffer) + buffer.position(); var locals = parseCodeSectionLocalTypes(buffer); - var instructions = new ArrayList(); + var instructions = new ArrayList(); var lastInstruction = false; ControlTree currentControlFlow = null; @@ -752,22 +760,22 @@ private static CodeSection parseCodeSection(ByteBuffer buffer) { case IF: { depth++; - instruction.setDepth(depth); - blockScope.push(instruction); - instruction.setScope(blockScope.peek()); + instruction.withDepth(depth); + blockScope.push(instruction.build()); + instruction.withScope(blockScope.peek()); break; } case END: { - instruction.setDepth(depth); + instruction.withDepth(depth); depth--; - instruction.setScope( - blockScope.isEmpty() ? instruction : blockScope.pop()); + instruction.withScope( + blockScope.isEmpty() ? instruction.build() : blockScope.pop()); break; } default: { - instruction.setDepth(depth); + instruction.withDepth(depth); break; } } @@ -790,28 +798,31 @@ private static CodeSection parseCodeSection(ByteBuffer buffer) { currentControlFlow.addCallback( end -> { // check that there is no "else" branch - if (instruction.labelFalse() == defaultJmp) { - instruction.setLabelFalse(end); + if (instruction.labelFalse().isPresent() + && instruction.labelFalse().get() == defaultJmp) { + instruction.withLabelFalse(end); } }); // defaults - instruction.setLabelTrue(defaultJmp); - instruction.setLabelFalse(defaultJmp); + instruction.withLabelTrue(defaultJmp); + instruction.withLabelFalse(defaultJmp); break; } case ELSE: { assert (currentControlFlow.instruction().opcode() == OpCode.IF); - currentControlFlow.instruction().setLabelFalse(instructions.size() + 1); + currentControlFlow + .instruction() + .withLabelFalse(instructions.size() + 1); - currentControlFlow.addCallback(instruction::setLabelTrue); + currentControlFlow.addCallback(instruction::withLabelTrue); break; } case BR_IF: { - instruction.setLabelFalse(instructions.size() + 1); + instruction.withLabelFalse(instructions.size() + 1); } // fallthrough case BR: @@ -825,13 +836,15 @@ private static CodeSection parseCodeSection(ByteBuffer buffer) { reference = reference.parent(); offset--; } - reference.addCallback(instruction::setLabelTrue); + reference.addCallback(instruction::withLabelTrue); break; } case BR_TABLE: { - instruction.setLabelTable(new int[instruction.operands().length]); - for (var idx = 0; idx < instruction.labelTable().length; idx++) { + var length = instruction.operands().length; + var labelTable = new ArrayList(); + for (var idx = 0; idx < length; idx++) { + labelTable.add(null); var offset = (int) instruction.operands()[idx]; ControlTree reference = currentControlFlow; while (offset > 0) { @@ -842,9 +855,9 @@ private static CodeSection parseCodeSection(ByteBuffer buffer) { offset--; } int finalIdx = idx; - reference.addCallback( - end -> instruction.labelTable()[finalIdx] = end); + reference.addCallback(end -> labelTable.set(finalIdx, end)); } + instruction.withLabelTable(labelTable); break; } case END: @@ -856,7 +869,7 @@ private static CodeSection parseCodeSection(ByteBuffer buffer) { if (lastInstruction && instructions.size() > 1) { var former = instructions.get(instructions.size() - 1); if (former.opcode() == OpCode.END) { - instruction.setScope(former.scope()); + instruction.withScope(former.scope()); } } break; @@ -869,7 +882,12 @@ private static CodeSection parseCodeSection(ByteBuffer buffer) { instructions.add(instruction); } while (!lastInstruction); - var functionBody = new FunctionBody(locals, instructions); + var functionBody = + new FunctionBody( + locals, + instructions.stream() + .map(ins -> ins.build()) + .collect(Collectors.toUnmodifiableList())); codeSection.addFunctionBody(functionBody); } @@ -911,7 +929,7 @@ private static DataCountSection parseDataCountSection(ByteBuffer buffer) { return DataCountSection.builder().withDataCount((int) dataCount).build(); } - private static Instruction parseInstruction(ByteBuffer buffer) { + private static Instruction.Builder parseInstruction(ByteBuffer buffer) { var address = buffer.position(); var b = (int) readByte(buffer) & 0xff; @@ -940,7 +958,10 @@ private static Instruction parseInstruction(ByteBuffer buffer) { break; } if (signature.length == 0) { - return new Instruction(address, op, new long[] {}); + return Instruction.builder() + .withAddress(address) + .withOpcode(op) + .withOperands(new long[] {}); } var operands = new ArrayList(); for (var sig : signature) { @@ -975,7 +996,10 @@ private static Instruction parseInstruction(ByteBuffer buffer) { operandsArray[i] = operands.get(i); } verifyAlignment(op, operandsArray); - return new Instruction(address, op, operandsArray); + return Instruction.builder() + .withAddress(address) + .withOpcode(op) + .withOperands(operandsArray); } private static void verifyAlignment(OpCode op, long[] operands) { @@ -1020,14 +1044,13 @@ private static void verifyAlignment(OpCode op, long[] operands) { } private static Instruction[] parseExpression(ByteBuffer buffer) { - var expr = new ArrayList(); while (true) { var i = parseInstruction(buffer); if (i.opcode() == OpCode.END) { break; } - expr.add(i); + expr.add(i.build()); } return expr.toArray(new Instruction[0]); } diff --git a/wasm/src/main/java/com/dylibso/chicory/wasm/Validator.java b/wasm/src/main/java/com/dylibso/chicory/wasm/Validator.java index 463903eea..9f08a14b8 100644 --- a/wasm/src/main/java/com/dylibso/chicory/wasm/Validator.java +++ b/wasm/src/main/java/com/dylibso/chicory/wasm/Validator.java @@ -397,7 +397,7 @@ public void validateFunction(int funcIdx, FunctionBody body, FunctionType functi case BR: { var n = (int) op.operands()[0]; - if (op.labelTrue() == null) { + if (op.labelTrue().isEmpty()) { throw new InvalidException("unknown label " + n); } popVals(labelTypes(getCtrl(n))); @@ -408,7 +408,7 @@ public void validateFunction(int funcIdx, FunctionBody body, FunctionType functi { popVal(ValueType.I32); var n = (int) op.operands()[0]; - if (op.labelTrue() == null) { + if (op.labelTrue().isEmpty()) { throw new InvalidException("unknown label " + n); } var labelTypes = labelTypes(getCtrl(n)); diff --git a/wasm/src/main/java/com/dylibso/chicory/wasm/types/Instruction.java b/wasm/src/main/java/com/dylibso/chicory/wasm/types/Instruction.java index b2aac94d8..9e355e783 100644 --- a/wasm/src/main/java/com/dylibso/chicory/wasm/types/Instruction.java +++ b/wasm/src/main/java/com/dylibso/chicory/wasm/types/Instruction.java @@ -1,22 +1,37 @@ package com.dylibso.chicory.wasm.types; import java.util.Arrays; +import java.util.List; +import java.util.Optional; public class Instruction { private final int address; private final OpCode opcode; private final long[] operands; // metadata fields - private Integer labelTrue; - private Integer labelFalse; - private int[] labelTable; - private int depth; - private Instruction scope; - - public Instruction(int address, OpCode opcode, long[] operands) { + private final int depth; + private final Optional labelTrue; + private final Optional labelFalse; + private final Optional> labelTable; + private final Optional scope; + + private Instruction( + int address, + OpCode opcode, + long[] operands, + int depth, + Optional labelTrue, + Optional labelFalse, + Optional> labelTable, + Optional scope) { this.address = address; this.opcode = opcode; this.operands = operands; + this.depth = depth; + this.labelTrue = labelTrue; + this.labelFalse = labelFalse; + this.labelTable = labelTable; + this.scope = scope; } public OpCode opcode() { @@ -40,43 +55,114 @@ public int address() { return address; } - public Integer labelTrue() { + public Optional labelTrue() { return labelTrue; } - public void setLabelTrue(Integer labelTrue) { - this.labelTrue = labelTrue; - } - - public Integer labelFalse() { + public Optional labelFalse() { return labelFalse; } - public void setLabelFalse(Integer labelFalse) { - this.labelFalse = labelFalse; - } - - public int[] labelTable() { + public Optional> labelTable() { return labelTable; } - public void setLabelTable(int[] labelTable) { - this.labelTable = labelTable; - } - public int depth() { return depth; } - public void setDepth(int depth) { - this.depth = depth; + public Optional scope() { + return scope; } - public Instruction scope() { - return scope; + public static Builder builder() { + return new Builder(); } - public void setScope(Instruction scope) { - this.scope = scope; + public static class Builder { + private int address; + private OpCode opcode; + private long[] operands; + private int depth; + private Optional labelTrue = Optional.empty(); + private Optional labelFalse = Optional.empty(); + private Optional> labelTable = Optional.empty(); + private Optional scope = Optional.empty(); + + private Builder() {} + + public OpCode opcode() { + return opcode; + } + + public Optional labelTrue() { + return labelTrue; + } + + public Optional labelFalse() { + return labelFalse; + } + + public Optional> labelTable() { + return labelTable; + } + + public long[] operands() { + return operands; + } + + public Optional scope() { + return scope; + } + + public Builder withAddress(int address) { + this.address = address; + return this; + } + + public Builder withOpcode(OpCode opcode) { + this.opcode = opcode; + return this; + } + + public Builder withOperands(long[] operands) { + this.operands = operands; + return this; + } + + public Builder withDepth(int depth) { + this.depth = depth; + return this; + } + + public Builder withLabelTrue(int label) { + this.labelTrue = Optional.of(label); + return this; + } + + public Builder withLabelFalse(int label) { + this.labelFalse = Optional.of(label); + return this; + } + + public Builder withLabelTable(List labelTable) { + this.labelTable = Optional.of(labelTable); + return this; + } + + public Builder withScope(Instruction scope) { + this.scope = Optional.of(scope); + return this; + } + + public Builder withScope(Optional scope) { + this.scope = scope; + return this; + } + + public Instruction build() { + return new Instruction( + address, opcode, operands, depth, labelTrue, labelFalse, labelTable, scope); + } } }