From 57c1a492cf15e33c85178cc17d40072d1a83240d Mon Sep 17 00:00:00 2001 From: Jonathan Coates Date: Wed, 7 Feb 2024 11:18:36 +0000 Subject: [PATCH] Derive operand registers rather than listing explicitly --- src/main/java/org/squiddev/cobalt/Lua.java | 2 +- .../org/squiddev/cobalt/OperationHelper.java | 146 ++++++++++-------- .../cobalt/function/LuaInterpreter.java | 16 +- src/test/resources/spec/_prelude.lua | 10 +- src/test/resources/spec/operation_spec.lua | 27 ++++ 5 files changed, 123 insertions(+), 78 deletions(-) diff --git a/src/main/java/org/squiddev/cobalt/Lua.java b/src/main/java/org/squiddev/cobalt/Lua.java index 74cfb1a2..b5e3455f 100644 --- a/src/main/java/org/squiddev/cobalt/Lua.java +++ b/src/main/java/org/squiddev/cobalt/Lua.java @@ -264,7 +264,7 @@ private static int opmode(int t, int a, int b, int c, int m) { return (t << 7) | (a << 6) | (b << 4) | (c << 2) | m; } - public static final int[] opmodes = { + private static final int[] opmodes = { opmode(0, 1, OpArgR, OpArgN, iABC), // OP_MOVE opmode(0, 1, OpArgK, OpArgN, iABx), // OP_LOADK opmode(0, 1, OpArgN, OpArgN, iABx), // OP_LOADKX diff --git a/src/main/java/org/squiddev/cobalt/OperationHelper.java b/src/main/java/org/squiddev/cobalt/OperationHelper.java index 717bd6e5..0d155085 100644 --- a/src/main/java/org/squiddev/cobalt/OperationHelper.java +++ b/src/main/java/org/squiddev/cobalt/OperationHelper.java @@ -31,6 +31,7 @@ import org.squiddev.cobalt.function.LuaFunction; import static org.squiddev.cobalt.Constants.*; +import static org.squiddev.cobalt.Lua.*; import static org.squiddev.cobalt.LuaDouble.valueOf; import static org.squiddev.cobalt.LuaInteger.valueOf; import static org.squiddev.cobalt.debug.DebugFrame.FLAG_LEQ; @@ -44,36 +45,24 @@ private OperationHelper() { //region Binary public static LuaValue add(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable { - return add(state, left, right, -1, -1); - } - - public static LuaValue add(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable { double dLeft, dRight; if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) { return valueOf(dLeft + dRight); } else { - return arithMetatable(state, ADD, left, right, leftIdx, rightIdx); + return arithMetatable(state, ADD, left, right); } } public static LuaValue sub(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable { - return sub(state, left, right, -1, -1); - } - - public static LuaValue sub(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable { double dLeft, dRight; if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) { return valueOf(dLeft - dRight); } else { - return arithMetatable(state, SUB, left, right, leftIdx, rightIdx); + return arithMetatable(state, SUB, left, right); } } public static LuaValue mul(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable { - return mul(state, left, right, -1, -1); - } - - public static LuaValue mul(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable { if (left instanceof LuaInteger l && right instanceof LuaInteger r) { return valueOf((long) l.intValue() * (long) r.intValue()); } @@ -82,46 +71,34 @@ public static LuaValue mul(LuaState state, LuaValue left, LuaValue right, int le if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) { return valueOf(dLeft * dRight); } else { - return arithMetatable(state, MUL, left, right, leftIdx, rightIdx); + return arithMetatable(state, MUL, left, right); } } public static LuaValue div(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable { - return div(state, left, right, -1, -1); - } - - public static LuaValue div(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable { double dLeft, dRight; if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) { return valueOf(div(dLeft, dRight)); } else { - return arithMetatable(state, DIV, left, right, leftIdx, rightIdx); + return arithMetatable(state, DIV, left, right); } } public static LuaValue mod(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable { - return mod(state, left, right, -1, -1); - } - - public static LuaValue mod(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable { double dLeft, dRight; if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) { return valueOf(mod(dLeft, dRight)); } else { - return arithMetatable(state, MOD, left, right, leftIdx, rightIdx); + return arithMetatable(state, MOD, left, right); } } public static LuaValue pow(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable { - return pow(state, left, right, -1, -1); - } - - public static LuaValue pow(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable { double dLeft, dRight; if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) { return valueOf(Math.pow(dLeft, dRight)); } else { - return arithMetatable(state, POW, left, right, leftIdx, rightIdx); + return arithMetatable(state, POW, left, right); } } @@ -155,18 +132,16 @@ public static double mod(double lhs, double rhs) { * Finds the supplied metatag value for {@code this} or {@code op2} and invokes it, * or throws {@link LuaError} if neither is defined. * - * @param state The current lua state - * @param tag The metatag to look up - * @param left The left operand value to perform the operation with - * @param right The other operand value to perform the operation with - * @param leftStack Stack index of the LHS - * @param rightStack Stack index of the RHS + * @param state The current lua state + * @param tag The metatag to look up + * @param left The left operand value to perform the operation with + * @param right The other operand value to perform the operation with * @return {@link LuaValue} resulting from metatag processing * @throws LuaError if metatag was not defined for either operand or the underlying operator errored. * @throws UnwindThrowable If calling the metatable function yielded. */ - public static LuaValue arithMetatable(LuaState state, LuaValue tag, LuaValue left, LuaValue right, int leftStack, int rightStack) throws LuaError, UnwindThrowable { - return Dispatch.call(state, getMetatable(state, tag, left, right, leftStack, rightStack), left, right); + private static LuaValue arithMetatable(LuaState state, LuaValue tag, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable { + return Dispatch.call(state, getMetatable(state, tag, left, right), left, right); } /** @@ -175,28 +150,54 @@ public static LuaValue arithMetatable(LuaState state, LuaValue tag, LuaValue lef * Finds the supplied metatag value for {@code this} or {@code op2} and invokes it, * or throws {@link LuaError} if neither is defined. * - * @param state The current lua state - * @param tag The metatag to look up - * @param left The left operand value to perform the operation with - * @param right The other operand value to perform the operation with - * @param leftStack Stack index of the LHS - * @param rightStack Stack index of the RHS + * @param state The current lua state + * @param tag The metatag to look up + * @param left The left operand value to perform the operation with + * @param right The other operand value to perform the operation with * @return {@link LuaValue} resulting from metatag processing * @throws LuaError if metatag was not defined for either operand */ - public static LuaValue getMetatable(LuaState state, LuaValue tag, LuaValue left, LuaValue right, int leftStack, int rightStack) throws LuaError { + private static LuaValue getMetatable(LuaState state, LuaValue tag, LuaValue left, LuaValue right) throws LuaError { LuaValue h = left.metatag(state, tag); - if (h.isNil()) { - h = right.metatag(state, tag); - if (h.isNil()) { - if (left.isNumber()) { - left = right; - leftStack = rightStack; - } - throw ErrorFactory.operandError(state, left, "perform arithmetic on", leftStack); + if (!h.isNil()) return h; + + h = right.metatag(state, tag); + if (!h.isNil()) return h; + + + throw createArithmeticError(state, left, right); + } + + private static LuaError createArithmeticError(LuaState state, LuaValue left, LuaValue right) { + // Read the current instruction and try to determine the registers involved. This allows us to avoid passing the + // registers from the interpreter to here. + // PUC Lua just does this by searching the stack for the given value, but that's not possible for us :( + int b = -1, c = -1; + DebugFrame frame = DebugState.get(state).getStack(); + if (frame != null && frame.closure != null) { + var prototype = frame.closure.getPrototype(); + if (frame.pc >= 0 && frame.pc <= prototype.code.length) { + int i = prototype.code[frame.pc]; + assert ( + getOpMode(GET_OPCODE(i)) == iABC && getBMode(GET_OPCODE(i)) == OpArgK && getCMode(GET_OPCODE(i)) == OpArgK + ) : getOpName(GET_OPCODE(i)) + " is not an iABC/RK/RX instruction"; + + b = GETARG_B(i); + c = GETARG_C(i); } } - return h; + + LuaValue value; + int stack; + if (!left.isNumber()) { + value = left; + stack = b; + } else { + value = right; + stack = c; + } + + return ErrorFactory.operandError(state, value, "perform arithmetic on", stack); } public static LuaValue concatNonStrings(LuaState state, LuaValue left, LuaValue right, int leftStack, int rightStack) throws LuaError, UnwindThrowable { @@ -303,10 +304,6 @@ public static boolean eq(LuaState state, LuaValue left, LuaValue right) throws L * @throws UnwindThrowable If the {@code __len} metamethod yielded. */ public static LuaValue length(LuaState state, LuaValue value) throws LuaError, UnwindThrowable { - return length(state, value, -1); - } - - public static LuaValue length(LuaState state, LuaValue value, int stack) throws LuaError, UnwindThrowable { switch (value.type()) { case Constants.TTABLE: { LuaValue h = value.metatag(state, CachedMetamethod.LEN); @@ -320,9 +317,7 @@ public static LuaValue length(LuaState state, LuaValue value, int stack) throws return valueOf(((LuaString) value).length()); default: { LuaValue h = value.metatag(state, CachedMetamethod.LEN); - if (h.isNil()) { - throw ErrorFactory.operandError(state, value, "get length of", stack); - } + if (h.isNil()) throw createUnaryOpError(state, value, "get length of"); return Dispatch.call(state, h, value); } } @@ -351,10 +346,6 @@ public static int intLength(LuaState state, LuaValue table) throws LuaError, Unw * @throws UnwindThrowable If the {@code __unm} metamethod yielded. */ public static LuaValue neg(LuaState state, LuaValue value) throws LuaError, UnwindThrowable { - return neg(state, value, -1); - } - - public static LuaValue neg(LuaState state, LuaValue value, int stack) throws LuaError, UnwindThrowable { int type = value.type(); if (type == TNUMBER) { if (value instanceof LuaInteger) { @@ -369,9 +360,7 @@ public static LuaValue neg(LuaState state, LuaValue value, int stack) throws Lua } LuaValue meta = value.metatag(state, Constants.UNM); - if (meta.isNil()) { - throw ErrorFactory.operandError(state, value, "perform arithmetic on", stack); - } + if (meta.isNil()) throw createUnaryOpError(state, value, "perform arithmetic on"); return Dispatch.call(state, meta, value); } @@ -379,6 +368,27 @@ public static LuaValue neg(LuaState state, LuaValue value, int stack) throws Lua private static boolean checkNumber(LuaValue lua, double value) { return lua.type() == TNUMBER || !Double.isNaN(value); } + + private static LuaError createUnaryOpError(LuaState state, LuaValue value, String message) { + // Read the current instruction and try to determine the register involved. This allows us to avoid passing the + // registers from the interpreter to here. + // PUC Lua just does this by searching the stack for the given value, but that's not possible for us :( + int b = -1; + DebugFrame frame = DebugState.get(state).getStack(); + if (frame != null && frame.closure != null) { + var prototype = frame.closure.getPrototype(); + if (frame.pc >= 0 && frame.pc <= prototype.code.length) { + int i = prototype.code[frame.pc]; + assert ( + getOpMode(GET_OPCODE(i)) == iABC && getBMode(GET_OPCODE(i)) == OpArgR + ) : getOpName(GET_OPCODE(i)) + " is not an iABC/RK instruction"; + + b = GETARG_B(i); + } + } + + return ErrorFactory.operandError(state, value, message, b); + } //endregion //region Tables diff --git a/src/main/java/org/squiddev/cobalt/function/LuaInterpreter.java b/src/main/java/org/squiddev/cobalt/function/LuaInterpreter.java index db713e2a..a44f7530 100644 --- a/src/main/java/org/squiddev/cobalt/function/LuaInterpreter.java +++ b/src/main/java/org/squiddev/cobalt/function/LuaInterpreter.java @@ -265,48 +265,48 @@ static Varargs execute(final LuaState state, DebugFrame di, LuaInterpretedFuncti case OP_ADD: { // A B C: R(A):= RK(B) + RK(C) int b = GETARG_B(i); int c = GETARG_C(i); - stack[a] = OperationHelper.add(state, getRK(stack, k, b), getRK(stack, k, c), b, c); + stack[a] = OperationHelper.add(state, getRK(stack, k, b), getRK(stack, k, c)); break; } case OP_SUB: { // A B C: R(A):= RK(B) - RK(C) int b = GETARG_B(i); int c = GETARG_C(i); - stack[a] = OperationHelper.sub(state, getRK(stack, k, b), getRK(stack, k, c), b, c); + stack[a] = OperationHelper.sub(state, getRK(stack, k, b), getRK(stack, k, c)); break; } case OP_MUL: { // A B C: R(A):= RK(B) * RK(C) int b = GETARG_B(i); int c = GETARG_C(i); - stack[a] = OperationHelper.mul(state, getRK(stack, k, b), getRK(stack, k, c), b, c); + stack[a] = OperationHelper.mul(state, getRK(stack, k, b), getRK(stack, k, c)); break; } case OP_DIV: { // A B C: R(A):= RK(B) / RK(C) int b = GETARG_B(i); int c = GETARG_C(i); - stack[a] = OperationHelper.div(state, getRK(stack, k, b), getRK(stack, k, c), b, c); + stack[a] = OperationHelper.div(state, getRK(stack, k, b), getRK(stack, k, c)); break; } case OP_MOD: { // A B C: R(A):= RK(B) % RK(C) int b = GETARG_B(i); int c = GETARG_C(i); - stack[a] = OperationHelper.mod(state, getRK(stack, k, b), getRK(stack, k, c), b, c); + stack[a] = OperationHelper.mod(state, getRK(stack, k, b), getRK(stack, k, c)); break; } case OP_POW: { // A B C: R(A):= RK(B) ^ RK(C) int b = GETARG_B(i); int c = GETARG_C(i); - stack[a] = OperationHelper.pow(state, getRK(stack, k, b), getRK(stack, k, c), b, c); + stack[a] = OperationHelper.pow(state, getRK(stack, k, b), getRK(stack, k, c)); break; } case OP_UNM: { // A B: R(A):= -R(B) int b = GETARG_B(i); - stack[a] = OperationHelper.neg(state, getRK(stack, k, b), b); + stack[a] = OperationHelper.neg(state, getRK(stack, k, b)); break; } @@ -316,7 +316,7 @@ static Varargs execute(final LuaState state, DebugFrame di, LuaInterpretedFuncti case OP_LEN: { // A B: R(A):= length of R(B) int b = GETARG_B(i); - stack[a] = OperationHelper.length(state, stack[b], b); + stack[a] = OperationHelper.length(state, stack[b]); break; } diff --git a/src/test/resources/spec/_prelude.lua b/src/test/resources/spec/_prelude.lua index 35ea8acf..20e20706 100644 --- a/src/test/resources/spec/_prelude.lua +++ b/src/test/resources/spec/_prelude.lua @@ -221,12 +221,20 @@ end --- Add extra information to this error message. -- -- @tparam string message Additional message to prepend in the case of failures. --- @return The current +-- @return The current expect object. function expect_mt:describe(message) self._extra = tostring(message) return self end +--- Remove the position information from an error message. +-- +-- @return The current expect object. +function expect_mt:strip_context() + self.value = self.value:gsub("^[^:]+:%d+: ", "") + return self +end + local expect = {} setmetatable(expect, expect) diff --git a/src/test/resources/spec/operation_spec.lua b/src/test/resources/spec/operation_spec.lua index 7931e365..33b3448d 100644 --- a/src/test/resources/spec/operation_spec.lua +++ b/src/test/resources/spec/operation_spec.lua @@ -13,4 +13,31 @@ describe("Lua's base operators", function() expect(public_key):eq(17511) end) end) + + describe("error messages", function() + local function mk_adder(k) return function() return 2 + k end end + + it("includes upvalue names in error messages :lua>=5.1 :lua<=5.2", function() + expect.error(mk_adder("hello")):strip_context():eq("attempt to perform arithmetic on upvalue 'k' (a string value)") + end) + + it("includes upvalue names in error messages :lua==5.3 :!cobalt", function() + expect.error(mk_adder("hello")):strip_context():eq("attempt to perform arithmetic on a string value (upvalue 'k')") + end) + + local function adder(k) return 2 + k end + + it("includes local names in error messages :lua>=5.1 :lua<=5.2", function() + expect.error(adder, "hello"):strip_context():eq("attempt to perform arithmetic on local 'k' (a string value)") + end) + + it("includes local names in error messages :lua==5.3 :!cobalt", function() + expect.error(adder, "hello"):strip_context():eq("attempt to perform arithmetic on a string value (local 'k')") + end) + + it("includes no information in error messages :lua>=5.4 :!cobalt", function() + expect.error(mk_adder("hello")):strip_context():eq("attempt to add a 'number' with a 'string'") + expect.error(adder, "hello"):strip_context():eq("attempt to add a 'number' with a 'string'") + end) + end) end)