Skip to content

Commit

Permalink
Derive operand registers rather than listing explicitly
Browse files Browse the repository at this point in the history
  • Loading branch information
SquidDev committed Feb 15, 2024
1 parent 4139b54 commit 57c1a49
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 78 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/squiddev/cobalt/Lua.java
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ private static int opmode(int t, int a, int b, int c, int m) {
return (t << 7) | (a << 6) | (b << 4) | (c << 2) | m;
}

public static final int[] opmodes = {
private static final int[] opmodes = {
opmode(0, 1, OpArgR, OpArgN, iABC), // OP_MOVE
opmode(0, 1, OpArgK, OpArgN, iABx), // OP_LOADK
opmode(0, 1, OpArgN, OpArgN, iABx), // OP_LOADKX
Expand Down
146 changes: 78 additions & 68 deletions src/main/java/org/squiddev/cobalt/OperationHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.squiddev.cobalt.function.LuaFunction;

import static org.squiddev.cobalt.Constants.*;
import static org.squiddev.cobalt.Lua.*;
import static org.squiddev.cobalt.LuaDouble.valueOf;
import static org.squiddev.cobalt.LuaInteger.valueOf;
import static org.squiddev.cobalt.debug.DebugFrame.FLAG_LEQ;
Expand All @@ -44,36 +45,24 @@ private OperationHelper() {

//region Binary
public static LuaValue add(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable {
return add(state, left, right, -1, -1);
}

public static LuaValue add(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable {
double dLeft, dRight;
if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) {
return valueOf(dLeft + dRight);
} else {
return arithMetatable(state, ADD, left, right, leftIdx, rightIdx);
return arithMetatable(state, ADD, left, right);
}
}

public static LuaValue sub(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable {
return sub(state, left, right, -1, -1);
}

public static LuaValue sub(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable {
double dLeft, dRight;
if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) {
return valueOf(dLeft - dRight);
} else {
return arithMetatable(state, SUB, left, right, leftIdx, rightIdx);
return arithMetatable(state, SUB, left, right);
}
}

public static LuaValue mul(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable {
return mul(state, left, right, -1, -1);
}

public static LuaValue mul(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable {
if (left instanceof LuaInteger l && right instanceof LuaInteger r) {
return valueOf((long) l.intValue() * (long) r.intValue());
}
Expand All @@ -82,46 +71,34 @@ public static LuaValue mul(LuaState state, LuaValue left, LuaValue right, int le
if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) {
return valueOf(dLeft * dRight);
} else {
return arithMetatable(state, MUL, left, right, leftIdx, rightIdx);
return arithMetatable(state, MUL, left, right);
}
}

public static LuaValue div(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable {
return div(state, left, right, -1, -1);
}

public static LuaValue div(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable {
double dLeft, dRight;
if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) {
return valueOf(div(dLeft, dRight));
} else {
return arithMetatable(state, DIV, left, right, leftIdx, rightIdx);
return arithMetatable(state, DIV, left, right);
}
}

public static LuaValue mod(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable {
return mod(state, left, right, -1, -1);
}

public static LuaValue mod(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable {
double dLeft, dRight;
if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) {
return valueOf(mod(dLeft, dRight));
} else {
return arithMetatable(state, MOD, left, right, leftIdx, rightIdx);
return arithMetatable(state, MOD, left, right);
}
}

public static LuaValue pow(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable {
return pow(state, left, right, -1, -1);
}

public static LuaValue pow(LuaState state, LuaValue left, LuaValue right, int leftIdx, int rightIdx) throws LuaError, UnwindThrowable {
double dLeft, dRight;
if (checkNumber(left, dLeft = left.toDouble()) && checkNumber(right, dRight = right.toDouble())) {
return valueOf(Math.pow(dLeft, dRight));
} else {
return arithMetatable(state, POW, left, right, leftIdx, rightIdx);
return arithMetatable(state, POW, left, right);
}
}

Expand Down Expand Up @@ -155,18 +132,16 @@ public static double mod(double lhs, double rhs) {
* Finds the supplied metatag value for {@code this} or {@code op2} and invokes it,
* or throws {@link LuaError} if neither is defined.
*
* @param state The current lua state
* @param tag The metatag to look up
* @param left The left operand value to perform the operation with
* @param right The other operand value to perform the operation with
* @param leftStack Stack index of the LHS
* @param rightStack Stack index of the RHS
* @param state The current lua state
* @param tag The metatag to look up
* @param left The left operand value to perform the operation with
* @param right The other operand value to perform the operation with
* @return {@link LuaValue} resulting from metatag processing
* @throws LuaError if metatag was not defined for either operand or the underlying operator errored.
* @throws UnwindThrowable If calling the metatable function yielded.
*/
public static LuaValue arithMetatable(LuaState state, LuaValue tag, LuaValue left, LuaValue right, int leftStack, int rightStack) throws LuaError, UnwindThrowable {
return Dispatch.call(state, getMetatable(state, tag, left, right, leftStack, rightStack), left, right);
private static LuaValue arithMetatable(LuaState state, LuaValue tag, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable {
return Dispatch.call(state, getMetatable(state, tag, left, right), left, right);
}

/**
Expand All @@ -175,28 +150,54 @@ public static LuaValue arithMetatable(LuaState state, LuaValue tag, LuaValue lef
* Finds the supplied metatag value for {@code this} or {@code op2} and invokes it,
* or throws {@link LuaError} if neither is defined.
*
* @param state The current lua state
* @param tag The metatag to look up
* @param left The left operand value to perform the operation with
* @param right The other operand value to perform the operation with
* @param leftStack Stack index of the LHS
* @param rightStack Stack index of the RHS
* @param state The current lua state
* @param tag The metatag to look up
* @param left The left operand value to perform the operation with
* @param right The other operand value to perform the operation with
* @return {@link LuaValue} resulting from metatag processing
* @throws LuaError if metatag was not defined for either operand
*/
public static LuaValue getMetatable(LuaState state, LuaValue tag, LuaValue left, LuaValue right, int leftStack, int rightStack) throws LuaError {
private static LuaValue getMetatable(LuaState state, LuaValue tag, LuaValue left, LuaValue right) throws LuaError {
LuaValue h = left.metatag(state, tag);
if (h.isNil()) {
h = right.metatag(state, tag);
if (h.isNil()) {
if (left.isNumber()) {
left = right;
leftStack = rightStack;
}
throw ErrorFactory.operandError(state, left, "perform arithmetic on", leftStack);
if (!h.isNil()) return h;

h = right.metatag(state, tag);
if (!h.isNil()) return h;


throw createArithmeticError(state, left, right);
}

private static LuaError createArithmeticError(LuaState state, LuaValue left, LuaValue right) {
// Read the current instruction and try to determine the registers involved. This allows us to avoid passing the
// registers from the interpreter to here.
// PUC Lua just does this by searching the stack for the given value, but that's not possible for us :(
int b = -1, c = -1;
DebugFrame frame = DebugState.get(state).getStack();
if (frame != null && frame.closure != null) {
var prototype = frame.closure.getPrototype();
if (frame.pc >= 0 && frame.pc <= prototype.code.length) {
int i = prototype.code[frame.pc];
assert (
getOpMode(GET_OPCODE(i)) == iABC && getBMode(GET_OPCODE(i)) == OpArgK && getCMode(GET_OPCODE(i)) == OpArgK
) : getOpName(GET_OPCODE(i)) + " is not an iABC/RK/RX instruction";

b = GETARG_B(i);
c = GETARG_C(i);
}
}
return h;

LuaValue value;
int stack;
if (!left.isNumber()) {
value = left;
stack = b;
} else {
value = right;
stack = c;
}

return ErrorFactory.operandError(state, value, "perform arithmetic on", stack);
}

public static LuaValue concatNonStrings(LuaState state, LuaValue left, LuaValue right, int leftStack, int rightStack) throws LuaError, UnwindThrowable {
Expand Down Expand Up @@ -303,10 +304,6 @@ public static boolean eq(LuaState state, LuaValue left, LuaValue right) throws L
* @throws UnwindThrowable If the {@code __len} metamethod yielded.
*/
public static LuaValue length(LuaState state, LuaValue value) throws LuaError, UnwindThrowable {
return length(state, value, -1);
}

public static LuaValue length(LuaState state, LuaValue value, int stack) throws LuaError, UnwindThrowable {
switch (value.type()) {
case Constants.TTABLE: {
LuaValue h = value.metatag(state, CachedMetamethod.LEN);
Expand All @@ -320,9 +317,7 @@ public static LuaValue length(LuaState state, LuaValue value, int stack) throws
return valueOf(((LuaString) value).length());
default: {
LuaValue h = value.metatag(state, CachedMetamethod.LEN);
if (h.isNil()) {
throw ErrorFactory.operandError(state, value, "get length of", stack);
}
if (h.isNil()) throw createUnaryOpError(state, value, "get length of");
return Dispatch.call(state, h, value);
}
}
Expand Down Expand Up @@ -351,10 +346,6 @@ public static int intLength(LuaState state, LuaValue table) throws LuaError, Unw
* @throws UnwindThrowable If the {@code __unm} metamethod yielded.
*/
public static LuaValue neg(LuaState state, LuaValue value) throws LuaError, UnwindThrowable {
return neg(state, value, -1);
}

public static LuaValue neg(LuaState state, LuaValue value, int stack) throws LuaError, UnwindThrowable {
int type = value.type();
if (type == TNUMBER) {
if (value instanceof LuaInteger) {
Expand All @@ -369,16 +360,35 @@ public static LuaValue neg(LuaState state, LuaValue value, int stack) throws Lua
}

LuaValue meta = value.metatag(state, Constants.UNM);
if (meta.isNil()) {
throw ErrorFactory.operandError(state, value, "perform arithmetic on", stack);
}
if (meta.isNil()) throw createUnaryOpError(state, value, "perform arithmetic on");

return Dispatch.call(state, meta, value);
}

private static boolean checkNumber(LuaValue lua, double value) {
return lua.type() == TNUMBER || !Double.isNaN(value);
}

private static LuaError createUnaryOpError(LuaState state, LuaValue value, String message) {
// Read the current instruction and try to determine the register involved. This allows us to avoid passing the
// registers from the interpreter to here.
// PUC Lua just does this by searching the stack for the given value, but that's not possible for us :(
int b = -1;
DebugFrame frame = DebugState.get(state).getStack();
if (frame != null && frame.closure != null) {
var prototype = frame.closure.getPrototype();
if (frame.pc >= 0 && frame.pc <= prototype.code.length) {
int i = prototype.code[frame.pc];
assert (
getOpMode(GET_OPCODE(i)) == iABC && getBMode(GET_OPCODE(i)) == OpArgR
) : getOpName(GET_OPCODE(i)) + " is not an iABC/RK instruction";

b = GETARG_B(i);
}
}

return ErrorFactory.operandError(state, value, message, b);
}
//endregion

//region Tables
Expand Down
16 changes: 8 additions & 8 deletions src/main/java/org/squiddev/cobalt/function/LuaInterpreter.java
Original file line number Diff line number Diff line change
Expand Up @@ -265,48 +265,48 @@ static Varargs execute(final LuaState state, DebugFrame di, LuaInterpretedFuncti
case OP_ADD: { // A B C: R(A):= RK(B) + RK(C)
int b = GETARG_B(i);
int c = GETARG_C(i);
stack[a] = OperationHelper.add(state, getRK(stack, k, b), getRK(stack, k, c), b, c);
stack[a] = OperationHelper.add(state, getRK(stack, k, b), getRK(stack, k, c));
break;
}

case OP_SUB: { // A B C: R(A):= RK(B) - RK(C)
int b = GETARG_B(i);
int c = GETARG_C(i);
stack[a] = OperationHelper.sub(state, getRK(stack, k, b), getRK(stack, k, c), b, c);
stack[a] = OperationHelper.sub(state, getRK(stack, k, b), getRK(stack, k, c));
break;
}

case OP_MUL: { // A B C: R(A):= RK(B) * RK(C)
int b = GETARG_B(i);
int c = GETARG_C(i);
stack[a] = OperationHelper.mul(state, getRK(stack, k, b), getRK(stack, k, c), b, c);
stack[a] = OperationHelper.mul(state, getRK(stack, k, b), getRK(stack, k, c));
break;
}

case OP_DIV: { // A B C: R(A):= RK(B) / RK(C)
int b = GETARG_B(i);
int c = GETARG_C(i);
stack[a] = OperationHelper.div(state, getRK(stack, k, b), getRK(stack, k, c), b, c);
stack[a] = OperationHelper.div(state, getRK(stack, k, b), getRK(stack, k, c));
break;
}

case OP_MOD: { // A B C: R(A):= RK(B) % RK(C)
int b = GETARG_B(i);
int c = GETARG_C(i);
stack[a] = OperationHelper.mod(state, getRK(stack, k, b), getRK(stack, k, c), b, c);
stack[a] = OperationHelper.mod(state, getRK(stack, k, b), getRK(stack, k, c));
break;
}

case OP_POW: { // A B C: R(A):= RK(B) ^ RK(C)
int b = GETARG_B(i);
int c = GETARG_C(i);
stack[a] = OperationHelper.pow(state, getRK(stack, k, b), getRK(stack, k, c), b, c);
stack[a] = OperationHelper.pow(state, getRK(stack, k, b), getRK(stack, k, c));
break;
}

case OP_UNM: { // A B: R(A):= -R(B)
int b = GETARG_B(i);
stack[a] = OperationHelper.neg(state, getRK(stack, k, b), b);
stack[a] = OperationHelper.neg(state, getRK(stack, k, b));
break;
}

Expand All @@ -316,7 +316,7 @@ static Varargs execute(final LuaState state, DebugFrame di, LuaInterpretedFuncti

case OP_LEN: { // A B: R(A):= length of R(B)
int b = GETARG_B(i);
stack[a] = OperationHelper.length(state, stack[b], b);
stack[a] = OperationHelper.length(state, stack[b]);
break;
}

Expand Down
10 changes: 9 additions & 1 deletion src/test/resources/spec/_prelude.lua
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,20 @@ end
--- Add extra information to this error message.
--
-- @tparam string message Additional message to prepend in the case of failures.
-- @return The current
-- @return The current expect object.
function expect_mt:describe(message)
self._extra = tostring(message)
return self
end

--- Remove the position information from an error message.
--
-- @return The current expect object.
function expect_mt:strip_context()
self.value = self.value:gsub("^[^:]+:%d+: ", "")
return self
end

local expect = {}
setmetatable(expect, expect)

Expand Down
27 changes: 27 additions & 0 deletions src/test/resources/spec/operation_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,31 @@ describe("Lua's base operators", function()
expect(public_key):eq(17511)
end)
end)

describe("error messages", function()
local function mk_adder(k) return function() return 2 + k end end

it("includes upvalue names in error messages :lua>=5.1 :lua<=5.2", function()
expect.error(mk_adder("hello")):strip_context():eq("attempt to perform arithmetic on upvalue 'k' (a string value)")
end)

it("includes upvalue names in error messages :lua==5.3 :!cobalt", function()
expect.error(mk_adder("hello")):strip_context():eq("attempt to perform arithmetic on a string value (upvalue 'k')")
end)

local function adder(k) return 2 + k end

it("includes local names in error messages :lua>=5.1 :lua<=5.2", function()
expect.error(adder, "hello"):strip_context():eq("attempt to perform arithmetic on local 'k' (a string value)")
end)

it("includes local names in error messages :lua==5.3 :!cobalt", function()
expect.error(adder, "hello"):strip_context():eq("attempt to perform arithmetic on a string value (local 'k')")
end)

it("includes no information in error messages :lua>=5.4 :!cobalt", function()
expect.error(mk_adder("hello")):strip_context():eq("attempt to add a 'number' with a 'string'")
expect.error(adder, "hello"):strip_context():eq("attempt to add a 'number' with a 'string'")
end)
end)
end)

0 comments on commit 57c1a49

Please sign in to comment.