From 9f03b7cd5df705cd8d4f31eae2785554511200c1 Mon Sep 17 00:00:00 2001 From: Jonathan Coates Date: Wed, 1 Nov 2023 21:35:29 +0000 Subject: [PATCH] Update to Lua 5.2 Yes, I could do this in multiple commits. Or I could chuck it in one horrifying mess! --- build.gradle.kts | 4 +- .../cc/tweaked/cobalt/internal/LegacyEnv.java | 55 ++ .../java/org/squiddev/cobalt/Constants.java | 5 + src/main/java/org/squiddev/cobalt/Lua.java | 137 ++-- .../java/org/squiddev/cobalt/LuaState.java | 13 +- .../java/org/squiddev/cobalt/LuaThread.java | 26 +- .../java/org/squiddev/cobalt/LuaValue.java | 24 - src/main/java/org/squiddev/cobalt/Print.java | 65 +- .../java/org/squiddev/cobalt/Prototype.java | 35 +- .../cobalt/compiler/BytecodeDumper.java | 79 +- .../cobalt/compiler/BytecodeLoader.java | 130 ++- .../org/squiddev/cobalt/compiler/ExpKind.java | 21 +- .../squiddev/cobalt/compiler/FuncState.java | 670 +++++++++------ .../org/squiddev/cobalt/compiler/Lex.java | 27 +- .../squiddev/cobalt/compiler/LoadState.java | 13 +- .../cobalt/compiler/LuaBytecodeFormat.java | 10 +- .../org/squiddev/cobalt/compiler/LuaC.java | 22 +- .../org/squiddev/cobalt/compiler/Parser.java | 769 ++++++++++++------ .../squiddev/cobalt/debug/DebugHelpers.java | 17 +- .../cobalt/function/HasEnvironment.java | 2 - .../squiddev/cobalt/function/LibFunction.java | 3 - .../squiddev/cobalt/function/LuaClosure.java | 5 - .../squiddev/cobalt/function/LuaFunction.java | 19 - .../function/LuaInterpretedFunction.java | 11 +- .../cobalt/function/LuaInterpreter.java | 150 ++-- .../java/org/squiddev/cobalt/lib/BaseLib.java | 25 +- .../squiddev/cobalt/lib/CoreLibraries.java | 2 +- .../org/squiddev/cobalt/lib/CoroutineLib.java | 5 +- .../org/squiddev/cobalt/lib/DebugLib.java | 13 +- .../cobalt/lib/system/SystemBaseLib.java | 2 +- .../java/org/squiddev/cobalt/AssertTests.java | 2 + .../squiddev/cobalt/CoroutineLoopTest.java | 2 +- .../org/squiddev/cobalt/CoroutineTest.java | 2 +- .../java/org/squiddev/cobalt/LuaSpecTest.java | 7 +- .../org/squiddev/cobalt/ScriptHelper.java | 14 +- .../cobalt/compiler/CompilerUnitTests.java | 1 + .../cobalt/lib/system/PackageLib.java | 7 +- .../org/squiddev/cobalt/vm/DataFactory.java | 3 +- .../squiddev/cobalt/vm/LuaOperationsTest.java | 139 +--- .../org/squiddev/cobalt/vm/LuaOperators.java | 5 +- .../org/squiddev/cobalt/vm/MetatableTest.java | 2 +- .../java/org/squiddev/cobalt/vm/TypeTest.java | 2 +- src/test/resources/assert/lex-context.lua | 2 +- src/test/resources/assert/lua5.1/calls.lua | 8 +- src/test/resources/assert/lua5.1/db.lua | 13 +- .../resources/bytecode-compiler/lua5.1/all.lc | Bin 5672 -> 5690 bytes .../resources/bytecode-compiler/lua5.1/api.lc | Bin 36097 -> 36616 bytes .../bytecode-compiler/lua5.1/attrib.lc | Bin 17976 -> 18016 bytes .../resources/bytecode-compiler/lua5.1/big.lc | Bin 17794 -> 16608 bytes .../bytecode-compiler/lua5.1/calls.lc | Bin 15887 -> 16437 bytes .../bytecode-compiler/lua5.1/checktable.lc | Bin 3946 -> 4106 bytes .../bytecode-compiler/lua5.1/closure.lc | Bin 20863 -> 21468 bytes .../bytecode-compiler/lua5.1/code.lc | Bin 6540 -> 7105 bytes .../bytecode-compiler/lua5.1/constructs.lc | Bin 12989 -> 13414 bytes .../resources/bytecode-compiler/lua5.1/db.lc | Bin 25715 -> 26178 bytes .../bytecode-compiler/lua5.1/errors.lc | Bin 12567 -> 12825 bytes .../bytecode-compiler/lua5.1/events.lc | Bin 21073 -> 21502 bytes .../bytecode-compiler/lua5.1/files.lc | Bin 18020 -> 18040 bytes .../resources/bytecode-compiler/lua5.1/gc.lc | Bin 14597 -> 14596 bytes .../bytecode-compiler/lua5.1/literals.lc | Bin 9212 -> 9185 bytes .../bytecode-compiler/lua5.1/locals.lc | Bin 6286 -> 5885 bytes .../bytecode-compiler/lua5.1/main.lc | Bin 5691 -> 5767 bytes .../bytecode-compiler/lua5.1/math.lc | Bin 12581 -> 12032 bytes .../bytecode-compiler/lua5.1/nextvar.lc | Bin 23925 -> 24311 bytes .../resources/bytecode-compiler/lua5.1/pm.lc | Bin 21273 -> 21413 bytes .../bytecode-compiler/lua5.1/sort.lc | Bin 4516 -> 4673 bytes .../bytecode-compiler/lua5.1/strings.lc | Bin 14035 -> 14068 bytes .../bytecode-compiler/lua5.1/vararg.lc | Bin 7001 -> 7249 bytes .../bytecode-compiler/lua5.1/verybig.lc | Bin 3269 -> 3378 bytes .../bytecode-compiler/regressions/bigattr.lc | Bin 206 -> 230 bytes .../regressions/comparators.lc | Bin 230 -> 254 bytes .../regressions/construct.lc | Bin 162 -> 186 bytes .../regressions/controlchars.l | Bin 0 -> 2173 bytes .../regressions/controlchars.lc | Bin 225 -> 229 bytes .../regressions/mathrandomseed.lc | Bin 159 -> 183 bytes .../bytecode-compiler/regressions/modulo.lc | Bin 220 -> 212 bytes .../regressions/multi-assign.lc | Bin 0 -> 253 bytes .../regressions/multi-assign.lua | 2 + .../bytecode-compiler/regressions/varargs.lc | Bin 2000 -> 2173 bytes src/test/resources/compare/coroutinelib.lua | 4 +- src/test/resources/compare/debuglib.lua | 2 +- src/test/resources/compare/debuglib.out | 18 +- src/test/resources/compare/errors/args.lua | 6 +- .../compare/errors/stringlibargs.out | 4 +- src/test/resources/compare/stringlib.out | 2 +- src/test/resources/compare/vm.lua | 3 +- src/test/resources/compare/vm.out | 20 +- src/test/resources/spec/base_spec.lua | 14 +- src/test/resources/spec/goto_spec.lua | 2 +- src/test/resources/spec/parser_spec.lua | 6 +- src/test/resources/spec/vararg_spec.lua | 4 +- 91 files changed, 1504 insertions(+), 1151 deletions(-) create mode 100644 src/main/java/cc/tweaked/cobalt/internal/LegacyEnv.java delete mode 100644 src/main/java/org/squiddev/cobalt/function/HasEnvironment.java create mode 100644 src/test/resources/bytecode-compiler/regressions/controlchars.l create mode 100644 src/test/resources/bytecode-compiler/regressions/multi-assign.lc create mode 100644 src/test/resources/bytecode-compiler/regressions/multi-assign.lua diff --git a/build.gradle.kts b/build.gradle.kts index 4ce74ba2..94199087 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -5,7 +5,7 @@ plugins { } group = "org.squiddev" -version = "0.7.3" +version = "0.8.0-SNAPSHOT" java { toolchain { @@ -120,7 +120,7 @@ publishing { pom { name.set("Cobalt") - description.set("A reentrant fork of LuaJ for Lua 5.1") + description.set("A reentrant fork of LuaJ for Lua 5.2") url.set("https://github.com/SquidDev/Cobalt") scm { diff --git a/src/main/java/cc/tweaked/cobalt/internal/LegacyEnv.java b/src/main/java/cc/tweaked/cobalt/internal/LegacyEnv.java new file mode 100644 index 00000000..91c16920 --- /dev/null +++ b/src/main/java/cc/tweaked/cobalt/internal/LegacyEnv.java @@ -0,0 +1,55 @@ +package cc.tweaked.cobalt.internal; + +import org.checkerframework.checker.nullness.qual.Nullable; +import org.squiddev.cobalt.Constants; +import org.squiddev.cobalt.LuaTable; +import org.squiddev.cobalt.LuaValue; +import org.squiddev.cobalt.Prototype; +import org.squiddev.cobalt.debug.Upvalue; +import org.squiddev.cobalt.function.LuaClosure; + +import java.util.Objects; + +/** + * Utilities for working with Lua 5.1-style {@code getfenv}/{@code setfenv}. + *

+ * These simply search for an {@link Constants#ENV _ENV} upvalue and set it. + */ +public final class LegacyEnv { + private LegacyEnv() { + } + + private static int findEnv(Prototype prototype) { + for (int i = 0; i < prototype.upvalues(); i++) { + if (Objects.equals(prototype.getUpvalueName(i), Constants.ENV)) return i; + } + + return -1; + } + + public static @Nullable LuaTable getEnv(LuaClosure closure) { + int index = findEnv(closure.getPrototype()); + return index >= 0 && closure.getUpvalue(index).getValue() instanceof LuaTable t ? t : null; + } + + public static @Nullable LuaTable getEnv(LuaValue value) { + return value instanceof LuaClosure c ? getEnv(c) : null; + } + + public static void setEnv(LuaClosure closure, LuaTable env) { + int index = findEnv(closure.getPrototype()); + if (index >= 0) { + // Slightly odd to create a new upvalue here, but ensures that it only affects this function. + closure.setUpvalue(index, new Upvalue(env)); + } + } + + public static boolean setEnv(LuaValue value, LuaTable env) { + if (!(value instanceof LuaClosure c)) return false; + + setEnv(c, env); + // We always return true on Lua closures, even if technically this won't do anything, as it ensures somewhat + // consistent behaviour. + return true; + } +} diff --git a/src/main/java/org/squiddev/cobalt/Constants.java b/src/main/java/org/squiddev/cobalt/Constants.java index a17608eb..eaea1fad 100644 --- a/src/main/java/org/squiddev/cobalt/Constants.java +++ b/src/main/java/org/squiddev/cobalt/Constants.java @@ -244,6 +244,11 @@ public class Constants { */ public static final LuaString LOADED = valueOf("_LOADED"); + /** + * LuaString constant with value "_ENV" for use as metatag + */ + public static final LuaString ENV = valueOf("_ENV"); + /** * Constant limiting metatag loop processing */ diff --git a/src/main/java/org/squiddev/cobalt/Lua.java b/src/main/java/org/squiddev/cobalt/Lua.java index 6eae3b06..74b6bbd7 100644 --- a/src/main/java/org/squiddev/cobalt/Lua.java +++ b/src/main/java/org/squiddev/cobalt/Lua.java @@ -36,13 +36,6 @@ public class Lua { */ public static final int LUA_MULTRET = -1; - /** - * Masks for new-style vararg - */ - public static final int VARARG_HASARG = 1; - public static final int VARARG_ISVARARG = 2; - public static final int VARARG_NEEDSARG = 4; - /*=========================================================================== We assume that instructions are unsigned numbers. All instructions have an opcode in the first 6 bits. @@ -50,6 +43,7 @@ public class Lua { `A' : 8 bits `B' : 9 bits `C' : 9 bits + 'Ax' : 26 bits ('A', 'B', and 'C' together) `Bx' : 18 bits (`B' and `C' together) `sBx' : signed Bx @@ -64,12 +58,14 @@ public class Lua { public static final int iABC = 0; public static final int iABx = 1; public static final int iAsBx = 2; + public static final int iAx = 3; // Size and position of opcode arguments. public static final int SIZE_C = 9; public static final int SIZE_B = 9; public static final int SIZE_Bx = SIZE_C + SIZE_B; public static final int SIZE_A = 8; + public static final int SIZE_Ax = SIZE_A + SIZE_B + SIZE_C; public static final int SIZE_OP = 6; @@ -78,20 +74,23 @@ public class Lua { public static final int POS_C = POS_A + SIZE_A; public static final int POS_B = POS_C + SIZE_C; public static final int POS_Bx = POS_C; + public static final int POS_Ax = POS_A; // Limits for opcode arguments. - public static final int MAX_OP = (1 << Lua.SIZE_OP) - 1; - public static final int MAXARG_A = (1 << Lua.SIZE_A) - 1; - public static final int MAXARG_B = (1 << Lua.SIZE_B) - 1; - public static final int MAXARG_C = (1 << Lua.SIZE_C) - 1; - public static final int MAXARG_Bx = (1 << Lua.SIZE_Bx) - 1; + public static final int MAX_OP = (1 << SIZE_OP) - 1; + public static final int MAXARG_A = (1 << SIZE_A) - 1; + public static final int MAXARG_B = (1 << SIZE_B) - 1; + public static final int MAXARG_C = (1 << SIZE_C) - 1; + public static final int MAXARG_Bx = (1 << SIZE_Bx) - 1; public static final int MAXARG_sBx = MAXARG_Bx >> 1; // sBx' is signed + public static final int MAXARG_Ax = (1 << SIZE_Ax) - 1; - public static final int MASK_OP = ((1 << Lua.SIZE_OP) - 1) << Lua.POS_OP; - public static final int MASK_A = ((1 << Lua.SIZE_A) - 1) << Lua.POS_A; - public static final int MASK_B = ((1 << Lua.SIZE_B) - 1) << Lua.POS_B; - public static final int MASK_C = ((1 << Lua.SIZE_C) - 1) << Lua.POS_C; - public static final int MASK_Bx = ((1 << Lua.SIZE_Bx) - 1) << Lua.POS_Bx; + public static final int MASK_OP = ((1 << SIZE_OP) - 1) << POS_OP; + public static final int MASK_A = ((1 << SIZE_A) - 1) << POS_A; + public static final int MASK_B = ((1 << SIZE_B) - 1) << POS_B; + public static final int MASK_C = ((1 << SIZE_C) - 1) << POS_C; + public static final int MASK_Bx = ((1 << SIZE_Bx) - 1) << POS_Bx; + public static final int MASK_Ax = ((1 << SIZE_Ax) - 1) << POS_Ax; // Utilities for reading instructions. @@ -119,6 +118,9 @@ public static int GETARG_sBx(int i) { return ((i >> POS_Bx) & MAXARG_Bx) - MAXARG_sBx; } + public static int GETARG_Ax(int i) { + return (i >> POS_Ax) & MAXARG_Ax; + } /** * This bit 1 means constant (0 means register) @@ -159,7 +161,7 @@ public static int RKASK(int x) { /** * Invalid register that fits in 8 bits */ - public static final int NO_REG = Lua.MAXARG_A; + public static final int NO_REG = MAXARG_A; /* ** R(x) - register @@ -168,57 +170,60 @@ public static int RKASK(int x) { */ public static final int OP_MOVE = 0; // A B R(A) := R(B) public static final int OP_LOADK = 1; // A Bx R(A) := Kst(Bx) - public static final int OP_LOADBOOL = 2; // A B C R(A) := (Bool)B; if (C) pc++ - public static final int OP_LOADNIL = 3; // A B R(A) := ... := R(B) := nil - public static final int OP_GETUPVAL = 4; // A B R(A) := UpValue[B] + public static final int OP_LOADKX = 2; // A R(A) := Kst(extra arg) + public static final int OP_LOADBOOL = 3; // A B C R(A) := (Bool)B; if (C) pc++ + public static final int OP_LOADNIL = 4; // A B R(A), R(A+1), ..., R(A+B) := nil + public static final int OP_GETUPVAL = 5; // A B R(A) := UpValue[B] + + public static final int OP_GETTABUP = 6; // A B C R(A) := UpValue[B][RK(C)] + public static final int OP_GETTABLE = 7; // A B C R(A) := R(B)[RK(C)] - public static final int OP_GETGLOBAL = 5; // A Bx R(A) := Gbl[Kst(Bx)] - public static final int OP_GETTABLE = 6; // A B C R(A) := R(B)[RK(C)] + public static final int OP_SETTABUP = 8; // A B C UpValue[A][RK(B)] := RK(C) + public static final int OP_SETUPVAL = 9; // A B UpValue[B] := R(A) + public static final int OP_SETTABLE = 10; // A B C R(A)[RK(B)] := RK(C) - public static final int OP_SETGLOBAL = 7; // A Bx Gbl[Kst(Bx)] := R(A) - public static final int OP_SETUPVAL = 8; // A B UpValue[B] := R(A) - public static final int OP_SETTABLE = 9; // A B C R(A)[RK(B)] := RK(C) + public static final int OP_NEWTABLE = 11; // A B C R(A) := {} (size = B,C) - public static final int OP_NEWTABLE = 10; // A B C R(A) := {} (size = B,C) + public static final int OP_SELF = 12; // A B C R(A+1) := R(B); R(A) := R(B)[RK(C)] - public static final int OP_SELF = 11; // A B C R(A+1) := R(B); R(A) := R(B)[RK(C)] + public static final int OP_ADD = 13; // A B C R(A) := RK(B) + RK(C) + public static final int OP_SUB = 14; // A B C R(A) := RK(B) - RK(C) + public static final int OP_MUL = 15; // A B C R(A) := RK(B) * RK(C) + public static final int OP_DIV = 16; // A B C R(A) := RK(B) / RK(C) + public static final int OP_MOD = 17; // A B C R(A) := RK(B) % RK(C) + public static final int OP_POW = 18; // A B C R(A) := RK(B) ^ RK(C) + public static final int OP_UNM = 19; // A B R(A) := -R(B) + public static final int OP_NOT = 20; // A B R(A) := not R(B) + public static final int OP_LEN = 21; // A B R(A) := length of R(B) - public static final int OP_ADD = 12; // A B C R(A) := RK(B) + RK(C) - public static final int OP_SUB = 13; // A B C R(A) := RK(B) - RK(C) - public static final int OP_MUL = 14; // A B C R(A) := RK(B) * RK(C) - public static final int OP_DIV = 15; // A B C R(A) := RK(B) / RK(C) - public static final int OP_MOD = 16; // A B C R(A) := RK(B) % RK(C) - public static final int OP_POW = 17; // A B C R(A) := RK(B) ^ RK(C) - public static final int OP_UNM = 18; // A B R(A) := -R(B) - public static final int OP_NOT = 19; // A B R(A) := not R(B) - public static final int OP_LEN = 20; // A B R(A) := length of R(B) + public static final int OP_CONCAT = 22; // A B C R(A) := R(B).. ... ..R(C) - public static final int OP_CONCAT = 21; // A B C R(A) := R(B).. ... ..R(C) + public static final int OP_JMP = 23; // A sBx pc+=sBx; if (A) close all upvalues >= R(A - 1) + public static final int OP_EQ = 24; // A B C if ((RK(B) == RK(C)) ~= A) then pc++ + public static final int OP_LT = 25; // A B C if ((RK(B) < RK(C)) ~= A) then pc++ + public static final int OP_LE = 26; // A B C if ((RK(B) <= RK(C)) ~= A) then pc++ - public static final int OP_JMP = 22; // sBx pc+=sBx + public static final int OP_TEST = 27; // A C if not (R(A) <=> C) then pc++ + public static final int OP_TESTSET = 28; // A B C if (R(B) <=> C) then R(A) := R(B) else pc++ - public static final int OP_EQ = 23; // A B C if ((RK(B) == RK(C)) ~= A) then pc++ - public static final int OP_LT = 24; // A B C if ((RK(B) < RK(C)) ~= A) then pc++ - public static final int OP_LE = 25; // A B C if ((RK(B) <= RK(C)) ~= A) then pc++ + public static final int OP_CALL = 29; // A B C R(A), ... ,R(A+C-2) := R(A)(R(A+1), ... ,R(A+B-1)) + public static final int OP_TAILCALL = 30; // A B C return R(A)(R(A+1), ... ,R(A+B-1)) + public static final int OP_RETURN = 31; // A B return R(A), ... ,R(A+B-2) (see note) - public static final int OP_TEST = 26; // A C if not (R(A) <=> C) then pc++ * - public static final int OP_TESTSET = 27; // A B C if (R(B) <=> C) then R(A) := R(B) else pc++ * + public static final int OP_FORLOOP = 32; // A sBx R(A)+=R(A+2; if R(A) =) R(A) - public static final int OP_CLOSURE = 36; // A Bx R(A) := closure(KPROTO[Bx], R(A), ... ,R(A+n)) + public static final int OP_VARARG = 38; // A B R(A), R(A+1), ..., R(A+B-2) = vararg - public static final int OP_VARARG = 37; // A B R(A), R(A+1), ..., R(A+B-1) = vararg - public static final int NUM_OPCODES = OP_VARARG + 1; + public static final int OP_EXTRAARG = 39; // Ax extra (larger) argument for previous opcode + public static final int NUM_OPCODES = OP_EXTRAARG + 1; /*=========================================================================== Notes: @@ -262,12 +267,13 @@ private static int opmode(int t, int a, int b, int c, int m) { public static final int[] opmodes = { opmode(0, 1, OpArgR, OpArgN, iABC), // OP_MOVE opmode(0, 1, OpArgK, OpArgN, iABx), // OP_LOADK + opmode(0, 1, OpArgN, OpArgN, iABx), // OP_LOADKX opmode(0, 1, OpArgU, OpArgU, iABC), // OP_LOADBOOL - opmode(0, 1, OpArgR, OpArgN, iABC), // OP_LOADNIL + opmode(0, 1, OpArgU, OpArgN, iABC), // OP_LOADNIL opmode(0, 1, OpArgU, OpArgN, iABC), // OP_GETUPVAL - opmode(0, 1, OpArgK, OpArgN, iABx), // OP_GETGLOBAL + opmode(0, 1, OpArgU, OpArgK, iABC), // OP_GETTABUP opmode(0, 1, OpArgR, OpArgK, iABC), // OP_GETTABLE - opmode(0, 0, OpArgK, OpArgN, iABx), // OP_SETGLOBAL + opmode(0, 0, OpArgK, OpArgK, iABC), // OP_SETTABUP opmode(0, 0, OpArgU, OpArgN, iABC), // OP_SETUPVAL opmode(0, 0, OpArgK, OpArgK, iABC), // OP_SETTABLE opmode(0, 1, OpArgU, OpArgU, iABC), // OP_NEWTABLE @@ -286,18 +292,19 @@ private static int opmode(int t, int a, int b, int c, int m) { opmode(1, 0, OpArgK, OpArgK, iABC), // OP_EQ opmode(1, 0, OpArgK, OpArgK, iABC), // OP_LT opmode(1, 0, OpArgK, OpArgK, iABC), // OP_LE - opmode(1, 1, OpArgR, OpArgU, iABC), // OP_TEST + opmode(1, 0, OpArgN, OpArgU, iABC), // OP_TEST opmode(1, 1, OpArgR, OpArgU, iABC), // OP_TESTSET opmode(0, 1, OpArgU, OpArgU, iABC), // OP_CALL opmode(0, 1, OpArgU, OpArgU, iABC), // OP_TAILCALL opmode(0, 0, OpArgU, OpArgN, iABC), // OP_RETURN opmode(0, 1, OpArgR, OpArgN, iAsBx), // OP_FORLOOP opmode(0, 1, OpArgR, OpArgN, iAsBx), // OP_FORPREP - opmode(1, 0, OpArgN, OpArgU, iABC), // OP_TFORLOOP + opmode(0, 0, OpArgN, OpArgU, iABC), // OP_TFORCALL + opmode(0, 1, OpArgR, OpArgN, iAsBx), // OP_TFORLOOP opmode(0, 0, OpArgU, OpArgU, iABC), // OP_SETLIST - opmode(0, 0, OpArgN, OpArgN, iABC), // OP_CLOSE opmode(0, 1, OpArgU, OpArgN, iABx), // OP_CLOSURE opmode(0, 1, OpArgU, OpArgN, iABC), // OP_VARARG + opmode(0, 0, OpArgU, OpArgU, iAx), // OP_EXTRAARG }; public static int getOpMode(int m) { @@ -328,12 +335,13 @@ public static boolean testTMode(int m) { private static final String[] opcodeNames = { "MOVE", "LOADK", + "LOADKX", "LOADBOOL", "LOADNIL", "GETUPVAL", - "GETGLOBAL", + "GETTABUP", "GETTABLE", - "SETGLOBAL", + "SETTABUP", "SETUPVAL", "SETTABLE", "NEWTABLE", @@ -359,11 +367,12 @@ public static boolean testTMode(int m) { "RETURN", "FORLOOP", "FORPREP", + "TFORCALL", "TFORLOOP", "SETLIST", - "CLOSE", "CLOSURE", "VARARG", + "EXTRAARG", }; public static String getOpName(int opcode) { diff --git a/src/main/java/org/squiddev/cobalt/LuaState.java b/src/main/java/org/squiddev/cobalt/LuaState.java index 8c875c1c..b8ef885f 100644 --- a/src/main/java/org/squiddev/cobalt/LuaState.java +++ b/src/main/java/org/squiddev/cobalt/LuaState.java @@ -89,6 +89,8 @@ public final class LuaState { */ private final LuaThread mainThread; + private final LuaTable globals = new LuaTable(); + /** * Report an internal VM error. */ @@ -106,7 +108,16 @@ private LuaState(Builder builder) { reportError = builder.reportError; bytecodeFormat = builder.bytecodeFormat; - mainThread = currentThread = new LuaThread(this, new LuaTable()); + mainThread = currentThread = new LuaThread(this); + } + + /** + * Get the global environment. + * + * @return The global environment. + */ + public LuaTable globals() { + return globals; } /** diff --git a/src/main/java/org/squiddev/cobalt/LuaThread.java b/src/main/java/org/squiddev/cobalt/LuaThread.java index 6aa0507f..7b97a706 100644 --- a/src/main/java/org/squiddev/cobalt/LuaThread.java +++ b/src/main/java/org/squiddev/cobalt/LuaThread.java @@ -96,11 +96,6 @@ public LuaValue getDisplayNameValue() { */ private Status status; - /** - * The environment this thread has. - */ - private LuaTable env; - /** * The function called when handling errors */ @@ -125,17 +120,14 @@ public LuaValue getDisplayNameValue() { * Constructor for main thread only * * @param state The current lua state - * @param env The thread's environment */ - public LuaThread(LuaState state, LuaTable env) { + public LuaThread(LuaState state) { super(Constants.TTHREAD); Objects.requireNonNull(state, "state cannot be null"); - Objects.requireNonNull(env, "env cannot be null"); status = Status.RUNNING; luaState = state; debugState = new DebugState(state); - this.env = env; function = null; } @@ -144,18 +136,15 @@ public LuaThread(LuaState state, LuaTable env) { * * @param state The current lua state * @param func The function to execute - * @param env The environment to apply to the thread */ - public LuaThread(LuaState state, LuaFunction func, LuaTable env) { + public LuaThread(LuaState state, LuaFunction func) { super(Constants.TTHREAD); Objects.requireNonNull(state, "state cannot be null"); Objects.requireNonNull(func, "func cannot be null"); - Objects.requireNonNull(env, "env cannot be null"); status = Status.INITIAL; luaState = state; debugState = new DebugState(state); - this.env = env; function = func; LuaThread current = state.getCurrentThread(); @@ -180,17 +169,6 @@ public LuaTable getMetatable(LuaState state) { return state.threadMetatable; } - @Override - public LuaTable getfenv() { - return env; - } - - @Override - public boolean setfenv(LuaTable env) { - this.env = env; - return true; - } - public Status getStatus() { return status; } diff --git a/src/main/java/org/squiddev/cobalt/LuaValue.java b/src/main/java/org/squiddev/cobalt/LuaValue.java index 018c95c2..779a054b 100644 --- a/src/main/java/org/squiddev/cobalt/LuaValue.java +++ b/src/main/java/org/squiddev/cobalt/LuaValue.java @@ -737,30 +737,6 @@ public void setMetatable(LuaState state, LuaTable metatable) throws LuaError { throw ErrorFactory.argError(this, "table"); } - /** - * Get the environment for an instance. - * - * @return {@link LuaValue} currently set as the instance's environent. - */ - public LuaValue getfenv() { - return NIL; - } - - /** - * Set the environment on an object. - *

- * However, any object can serve as an environment if it contains suitable metatag - * values to implement {@link OperationHelper#getTable(LuaState, LuaValue, LuaValue)} to provide the environment - * values. - * - * @param env {@link LuaValue} (typically a {@link LuaTable}) containing the environment. - * @return If the environment could be changed. - * @see CoreLibraries - */ - public boolean setfenv(LuaTable env) { - return false; - } - /** * Get particular metatag, or return {@link Constants#NIL} if it doesn't exist * diff --git a/src/main/java/org/squiddev/cobalt/Print.java b/src/main/java/org/squiddev/cobalt/Print.java index f481b382..eb087d74 100644 --- a/src/main/java/org/squiddev/cobalt/Print.java +++ b/src/main/java/org/squiddev/cobalt/Print.java @@ -32,7 +32,7 @@ /** * Debug helper class for pretty-printing Lua bytecode. *

- * This follows the implementation in {@code luac.c}/{@code print.c}. + * This follows the implementation in {@code luac.c}. * * @see Prototype * @see LuaClosure @@ -104,6 +104,7 @@ public static void printOpcode(StringBuilder out, Prototype f, int pc, boolean e int a = GETARG_A(i); int b = GETARG_B(i); int c = GETARG_C(i); + int ax = GETARG_Ax(i); int bx = GETARG_Bx(i); int sbx = GETARG_sBx(i); @@ -135,12 +136,11 @@ public static void printOpcode(StringBuilder out, Prototype f, int pc, boolean e } case iABx -> { out.append(a); - if (getBMode(o) == OpArgK) out.append(" ").append(MYK(INDEXK(bx))); + if (getBMode(o) == OpArgK) out.append(" ").append(MYK(bx)); if (getBMode(o) == OpArgU) out.append(" ").append(bx); } - case iAsBx -> { - out.append(a).append(" ").append(sbx); - } + case iAsBx -> out.append(a).append(" ").append(sbx); + case iAx -> out.append(a).append(" ").append(MYK(ax)); } switch (o) { @@ -150,16 +150,27 @@ public static void printOpcode(StringBuilder out, Prototype f, int pc, boolean e } case OP_GETUPVAL, OP_SETUPVAL -> { out.append("\t; "); - var upvalueName = f.getUpvalueName(b); - if (upvalueName != null) { - printString(out, upvalueName); - } else { - out.append("-"); + printUpvalueName(out, f, b); + } + case OP_GETTABUP -> { + out.append("\t; "); + printUpvalueName(out, f, b); + if (ISK(c)) { + out.append(" "); + printConstant(out, f, INDEXK(c)); } } - case OP_GETGLOBAL, OP_SETGLOBAL -> { + case OP_SETTABUP -> { out.append("\t; "); - printConstant(out, f, bx); + printUpvalueName(out, f, a); + if (ISK(b)) { + out.append(" "); + printConstant(out, f, INDEXK(b)); + } + if (ISK(c)) { + out.append(" "); + printConstant(out, f, INDEXK(c)); + } } case OP_GETTABLE, OP_SELF -> { if (ISK(c)) { @@ -183,7 +194,7 @@ public static void printOpcode(StringBuilder out, Prototype f, int pc, boolean e } } } - case OP_JMP, OP_FORLOOP, OP_FORPREP -> out.append("\t; to ").append(sbx + pc + 2); + case OP_JMP, OP_FORLOOP, OP_FORPREP, OP_TFORLOOP -> out.append("\t; to ").append(sbx + pc + 2); case OP_CLOSURE -> out.append("\t; ").append(id(f.children[bx])); case OP_SETLIST -> { if (c == 0) { @@ -192,12 +203,24 @@ public static void printOpcode(StringBuilder out, Prototype f, int pc, boolean e out.append("\t; ").append(c); } } - case OP_VARARG -> out.append("\t; is_vararg=").append(f.isVarArg); + case OP_EXTRAARG -> { + out.append("\t; "); + printConstant(out, f, ax); + } default -> { } } } + private static void printUpvalueName(StringBuilder out, Prototype f, int upvalue) { + var upvalueName = f.getUpvalueName(upvalue); + if (upvalueName != null) { + out.append(upvalueName); + } else { + out.append("-"); + } + } + private static void printHeader(StringBuilder out, Prototype f) { String s = String.valueOf(f.source); if (s.startsWith("@") || s.startsWith("=")) { @@ -210,7 +233,7 @@ private static void printHeader(StringBuilder out, Prototype f) { out.append("\n%").append(f.lineDefined == 0 ? "main" : "function") .append(" <").append(s).append(":").append(f.lineDefined).append(",").append(f.lastLineDefined).append("> (") .append(f.code.length).append(" instructions, ").append(f.code.length * 4).append(" bytes at ").append(id(f)).append(")\n"); - out.append(f.parameters).append(" param, ").append(f.maxStackSize).append(" slots, ").append(f.upvalueNames.length).append(" upvalues, "); + out.append(f.parameters).append(" param, ").append(f.maxStackSize).append(" slots, ").append(f.upvalues()).append(" upvalues, "); out.append(f.locals.length).append(" locals, ").append(f.constants.length).append(" constants, ").append(f.children.length).append(" functions\n"); } @@ -238,9 +261,15 @@ private static void printLocals(StringBuilder out, Prototype f) { } private static void printUpvalues(StringBuilder out, Prototype f) { - out.append("upvalues (").append(f.upvalues).append(") for ").append(id(f)).append(":\n"); - for (int i = 0, n = f.upvalues; i < n; i++) { - out.append("\t").append(i).append("\t").append(f.getUpvalueName(i)).append("\n"); + out.append("upvalues (").append(f.upvalues()).append(") for ").append(id(f)).append(":\n"); + for (int i = 0, n = f.upvalues(); i < n; i++) { + var upvalue = f.getUpvalue(i); + out + .append("\t").append(i) + .append("\t").append(upvalue.name()) + .append("\t").append(upvalue.fromLocal() ? '1' : '0') + .append("\t").append(upvalue.index()) + .append("\n"); } } diff --git a/src/main/java/org/squiddev/cobalt/Prototype.java b/src/main/java/org/squiddev/cobalt/Prototype.java index eecee25d..91c9c6d3 100644 --- a/src/main/java/org/squiddev/cobalt/Prototype.java +++ b/src/main/java/org/squiddev/cobalt/Prototype.java @@ -55,11 +55,12 @@ public final class Prototype { */ public final Prototype[] children; - public final int upvalues; + private final UpvalueInfo[] upvalues; + public final int lineDefined; public final int lastLineDefined; public final int parameters; - public final int isVarArg; + public final boolean isVarArg; public final int maxStackSize; /** @@ -74,12 +75,10 @@ public final class Prototype { */ public final LocalVariable[] locals; - public final LuaString[] upvalueNames; - public Prototype( LuaString source, LuaString shortSource, - LuaValue[] constants, int[] code, Prototype[] children, int parameters, int isVarArg, int maxStackSize, int upvalues, - int lineDefined, int lastLineDefined, int[] lineInfo, int[] columnInfo, LocalVariable[] locals, LuaString[] upvalueNames + LuaValue[] constants, int[] code, Prototype[] children, int parameters, boolean isVarArg, int maxStackSize, UpvalueInfo[] upvalues, + int lineDefined, int lastLineDefined, int[] lineInfo, int[] columnInfo, LocalVariable[] locals ) { this.source = source; this.shortSource = shortSource; @@ -97,7 +96,6 @@ public Prototype( this.lineInfo = lineInfo; this.columnInfo = columnInfo; this.locals = locals; - this.upvalueNames = upvalueNames; } public LuaString shortSource() { @@ -126,8 +124,16 @@ public String toString() { return null; // not found } + public int upvalues() { + return upvalues.length; + } + + public UpvalueInfo getUpvalue(int upvalue) { + return upvalues[upvalue]; + } + public @Nullable LuaString getUpvalueName(int index) { - return index >= 0 && index < upvalueNames.length ? upvalueNames[index] : null; + return index >= 0 && index < upvalues.length ? upvalues[index].name() : null; } /** @@ -149,4 +155,17 @@ public int lineAt(int pc) { public int columnAt(int pc) { return pc >= 0 && pc < columnInfo.length ? columnInfo[pc] : -1; } + + /** + * Information about an upvalue. + * + * @param name The name of this upvalue. + * @param fromLocal Whether this upvalue comes from a local (if true) or an upvalue (if false). + * @param byteIndex The short index of this upvalue. Use {@link #index()} when an int index is needed. + */ + public record UpvalueInfo(@Nullable LuaString name, boolean fromLocal, byte byteIndex) { + public int index() { + return byteIndex & 0xFF; + } + } } diff --git a/src/main/java/org/squiddev/cobalt/compiler/BytecodeDumper.java b/src/main/java/org/squiddev/cobalt/compiler/BytecodeDumper.java index 9542e2a1..54203da7 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/BytecodeDumper.java +++ b/src/main/java/org/squiddev/cobalt/compiler/BytecodeDumper.java @@ -24,6 +24,7 @@ */ package org.squiddev.cobalt.compiler; +import org.checkerframework.checker.nullness.qual.Nullable; import org.squiddev.cobalt.Constants; import org.squiddev.cobalt.LuaString; import org.squiddev.cobalt.LuaValue; @@ -34,8 +35,7 @@ import java.io.IOException; import java.io.OutputStream; -import static org.squiddev.cobalt.compiler.LuaBytecodeFormat.LUAC_FORMAT; -import static org.squiddev.cobalt.compiler.LuaBytecodeFormat.LUAC_VERSION; +import static org.squiddev.cobalt.compiler.LuaBytecodeFormat.*; class BytecodeDumper { private static final boolean IS_LITTLE_ENDIAN = true; @@ -48,8 +48,8 @@ class BytecodeDumper { private final DataOutputStream writer; private final boolean strip; - public BytecodeDumper(OutputStream w, boolean strip) { - this.writer = new DataOutputStream(w); + private BytecodeDumper(OutputStream w, boolean strip) { + writer = new DataOutputStream(w); this.strip = strip; } @@ -68,6 +68,14 @@ private void dumpInt(int x) throws IOException { } } + private void dumpNullableString(@Nullable LuaString s) throws IOException { + if (s == null) { + dumpInt(0); + } else { + dumpString(s); + } + } + private void dumpString(LuaString s) throws IOException { final int len = s.length(); dumpInt(len + 1); @@ -117,21 +125,22 @@ private void dumpConstants(final Prototype f) throws IOException { default -> throw new IllegalArgumentException("bad type for " + o); } } - n = f.children.length; - dumpInt(n); - for (i = 0; i < n; i++) { - dumpFunction(f.children[i], f.source); - } } private void dumpDebug(final Prototype f) throws IOException { + if (f.source == null || strip) { + dumpInt(0); + } else { + dumpString(f.source); + } + int i, n; - n = (strip) ? 0 : f.lineInfo.length; + n = strip ? 0 : f.lineInfo.length; dumpInt(n); for (i = 0; i < n; i++) { dumpInt(f.lineInfo[i]); } - n = (strip) ? 0 : f.locals.length; + n = strip ? 0 : f.locals.length; dumpInt(n); for (i = 0; i < n; i++) { LocalVariable lvi = f.locals[i]; @@ -139,30 +148,42 @@ private void dumpDebug(final Prototype f) throws IOException { dumpInt(lvi.startpc); dumpInt(lvi.endpc); } - n = (strip) ? 0 : f.upvalueNames.length; + n = strip ? 0 : f.upvalues(); dumpInt(n); - for (i = 0; i < n; i++) { - dumpString(f.upvalueNames[i]); - } + for (i = 0; i < n; i++) dumpNullableString(f.getUpvalueName(i)); } - private void dumpFunction(final Prototype f, final LuaString string) throws IOException { - if (f.source == null || f.source.equals(string) || strip) { - dumpInt(0); - } else { - dumpString(f.source); - } + private void dumpFunction(final Prototype f) throws IOException { dumpInt(f.lineDefined); dumpInt(f.lastLineDefined); - dumpChar(f.upvalues); dumpChar(f.parameters); - dumpChar(f.isVarArg); + dumpChar(f.isVarArg ? 1 : 0); dumpChar(f.maxStackSize); dumpCode(f); dumpConstants(f); + dumpFunctions(f); + dumpUpvalues(f); dumpDebug(f); } + private void dumpUpvalues(Prototype f) throws IOException { + int n = f.upvalues(); + dumpInt(n); + for (int i = 0; i < n; i++) { + var info = f.getUpvalue(i); + writer.writeBoolean(info.fromLocal()); + writer.writeByte(info.byteIndex()); + } + } + + private void dumpFunctions(Prototype f) throws IOException { + int n = f.children.length; + dumpInt(n); + for (int i = 0; i < n; i++) { + dumpFunction(f.children[i]); + } + } + private void dumpHeader() throws IOException { writer.write(LuaBytecodeFormat.LUA_SIGNATURE); writer.write(LUAC_VERSION); @@ -173,14 +194,20 @@ private void dumpHeader() throws IOException { writer.write(SIZEOF_INSTRUCTION); writer.write(SIZEOF_LUA_NUMBER); writer.write(NUMBER_FORMAT); + writer.write(LUAC_TAIL); } - /* - * Dump Lua function as precompiled chunk + /** + * Dump Lua functions as a precompiled chunk. + * + * @param f The function to dump. + * @param w The output stream to write to. + * @param strip Whether to strip debug informtion. + * @throws IOException If writing to the underlying stream failed. */ public static void dump(Prototype f, OutputStream w, boolean strip) throws IOException { BytecodeDumper D = new BytecodeDumper(w, strip); D.dumpHeader(); - D.dumpFunction(f, null); + D.dumpFunction(f); } } diff --git a/src/main/java/org/squiddev/cobalt/compiler/BytecodeLoader.java b/src/main/java/org/squiddev/cobalt/compiler/BytecodeLoader.java index 72d167bf..6fd172f7 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/BytecodeLoader.java +++ b/src/main/java/org/squiddev/cobalt/compiler/BytecodeLoader.java @@ -29,8 +29,7 @@ import org.squiddev.cobalt.unwind.AutoUnwind; import static org.squiddev.cobalt.Constants.*; -import static org.squiddev.cobalt.compiler.LuaBytecodeFormat.LUAC_VERSION; -import static org.squiddev.cobalt.compiler.LuaBytecodeFormat.LUA_SIGNATURE; +import static org.squiddev.cobalt.compiler.LuaBytecodeFormat.*; /** * Parser for bytecode @@ -76,7 +75,7 @@ final class BytecodeLoader { private static final LuaValue[] NOVALUES = {}; private static final Prototype[] NOPROTOS = {}; private static final LocalVariable[] NOLOCVARS = {}; - private static final LuaString[] NOSTRVALUES = {}; + private static final Prototype.UpvalueInfo[] NOUPVALUES = {}; private static final int[] NOINTS = {}; /** @@ -126,15 +125,11 @@ private int loadInt() throws CompileException, LuaError, UnwindThrowable { */ private int[] loadIntArray() throws CompileException, LuaError, UnwindThrowable { int n = loadInt(); - if (n == 0) { - return NOINTS; - } + if (n == 0) return NOINTS; // read all data at once int m = n << 2; - if (buf.length < m) { - buf = new byte[m]; - } + if (buf.length < m) buf = new byte[m]; readFully(buf, 0, m); int[] array = new int[n]; for (int i = 0, j = 0; i < n; ++i, j += 4) { @@ -169,10 +164,9 @@ private long loadInt64() throws CompileException, LuaError, UnwindThrowable { * @return the {@link LuaString} value laoded. */ private LuaString loadString() throws CompileException, LuaError, UnwindThrowable { - int size = this.luacSizeofSizeT == 8 ? (int) loadInt64() : loadInt(); - if (size == 0) { - return null; - } + int size = luacSizeofSizeT == 8 ? (int) loadInt64() : loadInt(); + if (size == 0) return null; + byte[] bytes = new byte[size]; readFully(bytes, 0, size); return LuaString.valueOf(bytes, 0, bytes.length - 1); @@ -185,22 +179,6 @@ private LuaString loadString() throws CompileException, LuaError, UnwindThrowabl * @return {@link LuaInteger} or {@link LuaDouble} whose value corresponds to the bits provided. */ public static LuaValue longBitsToLuaNumber(long bits) { - if ((bits & ((1L << 63) - 1)) == 0L) { - return Constants.ZERO; - } - - int e = (int) ((bits >> 52) & 0x7ffL) - 1023; - - if (e >= 0 && e < 31) { - long f = bits & 0xFFFFFFFFFFFFFL; - int shift = 52 - e; - long intPrecMask = (1L << shift) - 1; - if ((f & intPrecMask) == 0) { - int intValue = (int) (f >> shift) | (1 << e); - return LuaInteger.valueOf(((bits >> 63) != 0) ? -intValue : intValue); - } - } - return ValueFactory.valueOf(Double.longBitsToDouble(bits)); } @@ -211,11 +189,7 @@ public static LuaValue longBitsToLuaNumber(long bits) { * @throws CompileException, LuaError, UnwindThrowable if an i/o exception occurs */ private LuaValue loadNumber() throws CompileException, LuaError, UnwindThrowable { - if (luacNumberFormat == NUMBER_FORMAT_INTS_ONLY) { - return LuaInteger.valueOf(loadInt()); - } else { - return longBitsToLuaNumber(loadInt64()); - } + return longBitsToLuaNumber(loadInt64()); } /** @@ -229,8 +203,7 @@ private LuaValue[] loadConstants() throws CompileException, LuaError, UnwindThro for (int i = 0; i < n; i++) { values[i] = switch (readByte()) { case TNIL -> Constants.NIL; - case TBOOLEAN -> (0 != readUnsignedByte() ? Constants.TRUE : Constants.FALSE); - case TINT -> LuaInteger.valueOf(loadInt()); + case TBOOLEAN -> readByte() != 0 ? Constants.TRUE : Constants.FALSE; case TNUMBER -> loadNumber(); case TSTRING -> loadString(); default -> throw new IllegalStateException("bad constant"); @@ -239,11 +212,11 @@ private LuaValue[] loadConstants() throws CompileException, LuaError, UnwindThro return values; } - private Prototype[] loadChildren(LuaString source) throws CompileException, LuaError, UnwindThrowable { + private Prototype[] loadChildren() throws CompileException, LuaError, UnwindThrowable { int n = loadInt(); Prototype[] protos = n > 0 ? new Prototype[n] : NOPROTOS; for (int i = 0; i < n; i++) { - protos[i] = loadFunction(source); + protos[i] = loadFunction(); } return protos; } @@ -252,51 +225,64 @@ private LocalVariable[] loadLocals() throws CompileException, LuaError, UnwindTh int n = loadInt(); LocalVariable[] locals = n > 0 ? new LocalVariable[n] : NOLOCVARS; for (int i = 0; i < n; i++) { - LuaString varname = loadString(); + LuaString varName = loadString(); int startpc = loadInt(); int endpc = loadInt(); - locals[i] = new LocalVariable(varname, startpc, endpc); + locals[i] = new LocalVariable(varName, startpc, endpc); } return locals; } - private LuaString[] loadUpvalueNames() throws CompileException, LuaError, UnwindThrowable { + private void loadUpvaluesNames(Prototype.UpvalueInfo[] upvalues) throws CompileException, LuaError, UnwindThrowable { int n = loadInt(); - LuaString[] upvalueNames = n > 0 ? new LuaString[n] : NOSTRVALUES; - for (int i = 0; i < n; i++) upvalueNames[i] = loadString(); - return upvalueNames; + for (int i = 0; i < n; i++) { + var upvalue = upvalues[i]; + var name = loadString(); + upvalues[i] = new Prototype.UpvalueInfo(name, upvalue.fromLocal(), upvalue.byteIndex()); + } + } + + private Prototype.UpvalueInfo[] loadUpvalues() throws CompileException, LuaError, UnwindThrowable { + int n = loadInt(); + var upvalues = n > 0 ? new Prototype.UpvalueInfo[n] : NOUPVALUES; + for (int i = 0; i < n; i++) { + var inStack = readByte() != 0; + var slot = readByte(); + upvalues[i] = new Prototype.UpvalueInfo(null, inStack, slot); + } + return upvalues; } /** * Load a function prototype from the input stream * - * @param givenSource name of the source * @return {@link Prototype} instance that was loaded * @throws CompileException, LuaError, UnwindThrowable On stream read errors */ - public Prototype loadFunction(LuaString givenSource) throws CompileException, LuaError, UnwindThrowable { - LuaString source = loadString(); - if (source == null) source = givenSource; - - LuaString shortSource = LoadState.getShortName(source); - + public Prototype loadFunction() throws CompileException, LuaError, UnwindThrowable { int lineDefined = loadInt(); int lastLineDefined = loadInt(); - int nups = readUnsignedByte(); - int numparams = readUnsignedByte(); - int is_vararg = readUnsignedByte(); - int maxstacksize = readUnsignedByte(); + int numParams = readUnsignedByte(); + boolean isVarArg = readByte() != 0; + int maxStackSize = readUnsignedByte(); + int[] code = loadIntArray(); LuaValue[] constants = loadConstants(); - Prototype[] children = loadChildren(source); + Prototype[] children = loadChildren(); + Prototype.UpvalueInfo[] upvalues = loadUpvalues(); + + // See LoadDebug + var source = loadString(); + if (source == null) source = LuaString.valueOf("=?"); + int[] lineInfo = loadIntArray(); LocalVariable[] locals = loadLocals(); - LuaString[] upvalueNames = loadUpvalueNames(); + loadUpvaluesNames(upvalues); return new Prototype( - source, shortSource, - constants, code, children, numparams, is_vararg, maxstacksize, nups, - lineDefined, lastLineDefined, lineInfo, NOINTS, locals, upvalueNames + source, LoadState.getShortName(source), + constants, code, children, numParams, isVarArg, maxStackSize, upvalues, + lineDefined, lastLineDefined, lineInfo, NOINTS, locals ); } @@ -315,24 +301,24 @@ public void checkSignature() throws CompileException, LuaError, UnwindThrowable */ public void loadHeader() throws CompileException, LuaError, UnwindThrowable { int luacVersion = readByte(); - if (luacVersion != LUAC_VERSION) throw new CompileException("unsupported luac version"); + if (luacVersion != LUAC_VERSION) throw new CompileException("version mismatch"); int luacFormat = readByte(); - luacLittleEndian = (0 != readByte()); + if (luacFormat != LUAC_FORMAT) throw new CompileException("incompatible"); + + luacLittleEndian = readByte() != 0; int luacSizeofInt = readByte(); luacSizeofSizeT = readByte(); int luacSizeofInstruction = readByte(); int luacSizeofLuaNumber = readByte(); - luacNumberFormat = readByte(); - - // check format - switch (luacNumberFormat) { - case NUMBER_FORMAT_FLOATS_OR_DOUBLES: - case NUMBER_FORMAT_INTS_ONLY: - case NUMBER_FORMAT_NUM_PATCH_INT32: - break; - default: - throw new CompileException("unsupported int size"); + int luacNumberFormat = readByte(); + + if (luacSizeofInt != 4 || luacSizeofInstruction != 4 || luacSizeofLuaNumber != 8 || luacNumberFormat != 0) { + throw new CompileException("incompatible"); + } + + for (byte b : LUAC_TAIL) { + if (readByte() != b) throw new CompileException("incompatible"); } } } diff --git a/src/main/java/org/squiddev/cobalt/compiler/ExpKind.java b/src/main/java/org/squiddev/cobalt/compiler/ExpKind.java index 92c20f77..b5b9f6a5 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/ExpKind.java +++ b/src/main/java/org/squiddev/cobalt/compiler/ExpKind.java @@ -31,6 +31,11 @@ enum ExpKind { */ VKNUM, + /** + * info = localresult + */ + VNONRELOC, + /** * info = local register */ @@ -42,12 +47,7 @@ enum ExpKind { VUPVAL, /** - * info = index of table, aux = index of global name in `k`. - */ - VGLOBAL, - - /** - * info = table register, aux = index register (or `k`) + * info = table register/upvalue ("t"), aux = index register ("idx") */ VINDEXED, @@ -61,11 +61,6 @@ enum ExpKind { */ VRELOCABLE, - /** - * info = result register - */ - VNONRELOC, - /** * info = instruction pc */ @@ -80,6 +75,10 @@ boolean hasMultiRet() { return this == VCALL || this == VVARARG; } + boolean isInRegister() { + return this == VNONRELOC || this == VLOCAL; + } + boolean isVar() { return VLOCAL.ordinal() <= ordinal() && ordinal() <= VINDEXED.ordinal(); } diff --git a/src/main/java/org/squiddev/cobalt/compiler/FuncState.java b/src/main/java/org/squiddev/cobalt/compiler/FuncState.java index aebb30ff..c2e22a44 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/FuncState.java +++ b/src/main/java/org/squiddev/cobalt/compiler/FuncState.java @@ -45,24 +45,13 @@ * in PUC Lua). */ final class FuncState { - static class UpvalueDesc { - final LuaString name; - final ExpKind kind; - final short info; - - UpvalueDesc(LuaString name, ExpKind kind, short info) { - this.name = name; - this.kind = kind; - this.info = info; - } - } - static class BlockCnt { BlockCnt previous; /* chain */ - IntPtr breaklist = new IntPtr(); /* list of jumps out of this loop */ - short nactvar; /* # active locals outside the breakable structure */ + short firstLabel; /* Index of first label in this block. */ + short firstGoto; /* Index of first pending goto in this block. */ + short activeVariableCount; /* active locals outside the breakable structure */ boolean upval; /* true if some variable in the block is an upvalue */ - boolean isbreakable; /* true if `block' is a loop */ + boolean isLoop; /* true if `block' is a loop */ } final FuncState prev; /* enclosing function */ @@ -72,7 +61,7 @@ static class BlockCnt { private final Map constantLookup = new HashMap<>(); /* table to find (and reuse) elements in `k' */ final List locals = new ArrayList<>(0); final List children = new ArrayList<>(0); - final List upvalues = new ArrayList<>(0); /* upvalues */ + final List upvalues = new ArrayList<>(0); /* upvalues */ int pc; /* next position to code (equivalent to `ncode') */ int[] code; @@ -82,81 +71,106 @@ static class BlockCnt { int lineDefined; int lastLineDefined; int numParams; - int varargFlags; + boolean isVararg; int maxStackSize = 2; BlockCnt block; /* chain of current blocks */ - int lastTarget = -1; /* `pc' of last `jump target' */ + int lastTarget = 0; /* `pc' of last `jump target' */ final IntPtr jpc = new IntPtr(NO_JUMP); /* list of pending jumps to `pc' */ int freeReg; /* first free register */ short activeVariableCount; /* number of active local variables */ - short[] activeVariables = new short[LUAI_MAXVARS]; /* declared-variable stack */ - FuncState(Lex lexer, FuncState prev) { + final int firstLocal; + + FuncState(Lex lexer, FuncState prev, int firstLocal) { this.lexer = lexer; this.prev = prev; + this.firstLocal = firstLocal; } Prototype toPrototype() { - int i = 0; - LuaString[] upvalueNames = new LuaString[upvalues.size()]; - for (FuncState.UpvalueDesc upvalue : upvalues) upvalueNames[i++] = upvalue.name; - return new Prototype( lexer.source, lexer.shortSource, // Code constants.toArray(new LuaValue[0]), LuaC.realloc(code, pc), children.toArray(new Prototype[0]), - numParams, varargFlags, maxStackSize, upvalues.size(), + numParams, isVararg, maxStackSize, upvalues.toArray(Prototype.UpvalueInfo[]::new), // Debug information lineDefined, lastLineDefined, LuaC.realloc(lineInfo, pc), LuaC.realloc(columnInfo, pc), - locals.toArray(new LocalVariable[0]), upvalueNames + locals.toArray(new LocalVariable[0]) ); } - int codeAsBxAt(int o, int A, int sBx, long position) throws CompileException { - return codeABxAt(o, A, sBx + MAXARG_sBx, position); - } - - int codeAsBx(int o, int A, int sBx) throws CompileException { - return codeABx(o, A, sBx + MAXARG_sBx); - } void setMultiRet(ExpDesc e) throws CompileException { setReturns(e, LUA_MULTRET); } - LocalVariable getLocal(int i) { - return locals.get(activeVariables[i]); - } - // ============================================================= // from lcode.c // ============================================================= void nil(int from, int n) throws CompileException { + int to = from + n - 1; if (pc > lastTarget) { /* no jumps to current position? */ - if (pc == 0) { /* function start? */ - if (from >= activeVariableCount) { - return; /* positions are already clean */ - } - } else { - int previous = code[pc - 1]; - if (GET_OPCODE(previous) == OP_LOADNIL) { - int pfrom = GETARG_A(previous); - int pto = GETARG_B(previous); - if (pfrom <= from && from <= pto + 1) { /* can connect both? */ - if (from + n - 1 > pto) code[pc - 1] = SETARG_B(previous, from + n - 1); - return; - } + int previous = code[pc - 1]; + if (GET_OPCODE(previous) == OP_LOADNIL) { + int pFrom = GETARG_A(previous); + int pTo = pFrom + GETARG_B(previous); + if ((pFrom <= from && from <= pTo + 1) || (from <= pFrom && pFrom <= to + 1)) { /* can connect both? */ + if (pFrom <= from) from = pFrom; + if (pTo > to) to = pTo; + + previous = SETARG_A(previous, from); + previous = SETARG_B(previous, to - from); + code[pc - 1] = previous; + return; } } } - /* else no optimization */ - codeABC(OP_LOADNIL, from, from + n - 1, 0); + + codeABC(OP_LOADNIL, from, n - 1, 0); // else no optimization + } + + private int getJump(int pc) { + int offset = GETARG_sBx(code[pc]); + if (offset == NO_JUMP) { // point to itself represents end of list + return NO_JUMP; // end of list + } else { // turn offset into absolute position + return pc + 1 + offset; + } } + private void fixJump(int pc, int dest) throws CompileException { + int offset = dest - (pc + 1); + assert dest != NO_JUMP; + if (Math.abs(offset) > MAXARG_sBx) throw lexer.syntaxError("control structure too long"); + code[pc] = SETARG_sBx(code[pc], offset); + } + + /** + * Concatenate jump list {@code l2} into jump-list {@code l1} + */ + void concat(IntPtr l1, int l2) throws CompileException { + if (l2 == NO_JUMP) return; + if (l1.value == NO_JUMP) { + l1.value = l2; + } else { + int list = l1.value; + // find last element + int next; + while ((next = getJump(list)) != NO_JUMP) list = next; + fixJump(list, l2); + } + } + + /** + * Create a jump instruction and return its position, so it can e fixed up later with {@link #fixJump(int, int)}. + *

+ * If there are jumps to this position (kept in {@link #jpc}) link them together so that + * {@link #patchListAux(int, int, int, int)} will fix them to their final destination. + */ int jump() throws CompileException { int jpc = this.jpc.value; /* save list of jumps to here */ this.jpc.value = NO_JUMP; @@ -165,21 +179,21 @@ int jump() throws CompileException { return j.value; } + /** + * Code a {@link Lua#OP_RETURN} instruction. + */ void ret(int first, int nret) throws CompileException { codeABC(OP_RETURN, first, nret + 1, 0); } + /** + * Code a conditional jump namely {@link Lua#OP_TEST}/{@link Lua#OP_TESTSET}. + */ private int condJump(int op, int A, int B, int C, long position) throws CompileException { codeABCAt(op, A, B, C, position); return jump(); } - private void fixJump(int pc, int dest) throws CompileException { - int offset = dest - (pc + 1); - assert dest != NO_JUMP; - if (Math.abs(offset) > MAXARG_sBx) throw lexer.syntaxError("control structure too long"); - code[pc] = SETARG_sBx(code[pc], offset); - } /* * Returns current `pc' and marks it as a jump target (to avoid wrong @@ -190,15 +204,10 @@ int getLabel() { return pc; } - private int getJump(int pc) { - int offset = GETARG_sBx(code[pc]); - if (offset == NO_JUMP) { // point to itself represents end of list - return NO_JUMP; // end of list - } else { // turn offset into absolute position - return pc + 1 + offset; - } - } - + /** + * Returns the position of the instruction "controlling" a given jump (that is, its condition), or the jump itself + * if it is unconditional. + */ private int getJumpControl(int pc) { if (pc >= 1 && testTMode(GET_OPCODE(code[pc - 1]))) { return pc - 1; @@ -207,17 +216,14 @@ private int getJumpControl(int pc) { } } - /* - * Check whether list has any jump that do not produce a value - * (or produce an inverted value). - */ - private boolean needValue(int list) { - for (; list != NO_JUMP; list = getJump(list)) { - if (GET_OPCODE(code[getJumpControl(list)]) != OP_TESTSET) return true; - } - return false; // not found - } + /** + * Patch destination register for a {@link Lua#OP_TESTSET} instruction. + *

+ * If instruction in position 'node' is not a {@link Lua#OP_TESTSET}, return {@code false} Otherwise, if 'reg' + * is not {@link Lua#NO_REG}, set it as the destination register. Otherwise, change instruction to a simple + * {@link Lua#OP_TEST} (produces no register value). + */ private boolean patchTestReg(int node, int reg) { int jumpControlPc = getJumpControl(node); int op = code[jumpControlPc]; @@ -234,12 +240,21 @@ private boolean patchTestReg(int node, int reg) { return true; } + /** + * Traverse a list of tests, ensuring no one produces a value. + * + * @param list The head of the jump list. + */ private void removeValues(int list) { for (; list != NO_JUMP; list = getJump(list)) { patchTestReg(list, NO_REG); } } + /** + * Traverse a list of tests, patching their destination address and registers: tests producing values jump to + * 'vtarget' (and put their values in 'reg'), other tests jump to 'dtarget'. + */ private void patchListAux(int list, int vtarget, int reg, int dtarget) throws CompileException { while (list != NO_JUMP) { int next = getJump(list); @@ -252,38 +267,118 @@ private void patchListAux(int list, int vtarget, int reg, int dtarget) throws Co } } + /** + * Ensure all pending jumps to current position are fixed (jumping to current position with no values) and reset + * list of pending jumps + */ private void dischargeJumpPc() throws CompileException { patchListAux(jpc.value, pc, NO_REG, pc); jpc.value = NO_JUMP; } + /** + * Add elements in 'list' to list of pending jumps to "here" (current position) + */ + void patchToHere(int list) throws CompileException { + getLabel(); // Mark here as a target + concat(jpc, list); + } + + /** + * Path all jumps in 'list' to jump to 'target'. (The assert means that we cannot fix a jump to a forward address + * because we only know addresses once code is generated.) + */ void patchList(int list, int target) throws CompileException { if (target == pc) { patchToHere(list); } else { - _assert(target < pc); + assert target < pc; patchListAux(list, target, NO_REG, target); } } - void patchToHere(int list) throws CompileException { - getLabel(); - concat(jpc, list); + /** + * Path all jumps in 'list' to close upvalues up to given 'level' (The assertion checks that jumps either were + * closing nothing or were closing higher levels, from inner blocks.) + */ + void patchClose(int list, int level) { + level++; /* argument is +1 to reserve 0 as non-op */ + for (; list != NO_JUMP; list = getJump(list)) { + assert GET_OPCODE(code[list]) == OP_JMP && (GETARG_A(code[list]) == 0 || GETARG_A(code[list]) >= level); + code[list] = SETARG_A(code[list], level); + } } - void concat(IntPtr l1, int l2) throws CompileException { - if (l2 == NO_JUMP) return; - if (l1.value == NO_JUMP) { - l1.value = l2; + /** + * Emit the instruction at a specific location. + * + * @param instruction The bytecode of the instruction. + * @param position The packed position of this instruction. + * @return The position of this instruction. + */ + private int code(int instruction, long position) throws CompileException { + assert position > 0; + + dischargeJumpPc(); /* `pc' will change */ + + // put new instruction in code array + if (code == null || pc + 1 > code.length) code = LuaC.realloc(code, pc * 2 + 1); + code[pc] = instruction; + + // save corresponding line information + if (lineInfo == null || pc + 1 > lineInfo.length) { + lineInfo = LuaC.realloc(lineInfo, pc * 2 + 1); + columnInfo = LuaC.realloc(columnInfo, pc * 2 + 1); + } + lineInfo[pc] = Lex.unpackLine(position); + columnInfo[pc] = Lex.unpackColumn(position); + + return pc++; + } + + int codeABCAt(int o, int a, int b, int c, long position) throws CompileException { + assert getOpMode(o) == iABC; + assert getBMode(o) != OpArgN || b == 0; + assert getCMode(o) != OpArgN || c == 0; + return code(CREATE_ABC(o, a, b, c), position); + } + + int codeABC(int o, int a, int b, int c) throws CompileException { + return codeABCAt(o, a, b, c, lexer.lastPosition()); + } + + int codeABxAt(int o, int a, int bc, long position) throws CompileException { + assert getOpMode(o) == iABx || getOpMode(o) == iAsBx; + assert getCMode(o) == OpArgN; + assert position > 0; + return code(CREATE_ABx(o, a, bc), position); + } + + int codeABx(int o, int a, int bc) throws CompileException { + return codeABxAt(o, a, bc, lexer.lastPosition()); + } + + int codeAsBxAt(int o, int A, int sBx, long position) throws CompileException { + return codeABxAt(o, A, sBx + MAXARG_sBx, position); + } + + int codeAsBx(int o, int A, int sBx) throws CompileException { + return codeABx(o, A, sBx + MAXARG_sBx); + } + + int codeK(int reg, int k) throws CompileException { + if (k <= MAXARG_Bx) { + return codeABx(OP_LOADK, reg, k); } else { - int list = l1.value; - // find last element - int next; - while ((next = getJump(list)) != NO_JUMP) list = next; - fixJump(list, l2); + int p = codeABx(OP_LOADKX, reg, 0); + code(CREATE_Ax(OP_EXTRAARG, k), lexer.lastPosition()); + return p; } } + /** + * Check register stack level, keeping track of its maximum size. + */ void checkStack(int n) throws CompileException { int newStack = freeReg + n; if (newStack > maxStackSize) { @@ -292,22 +387,46 @@ void checkStack(int n) throws CompileException { } } + /** + * Reserve {@code n} registers. + */ void reserveRegs(int n) throws CompileException { checkStack(n); freeReg += n; } - private void freeReg(int reg) throws CompileException { + /** + * Free the given register (this may be a constant index, in which case this is a no-op). + */ + private void freeReg(int reg) { if (!ISK(reg) && reg >= activeVariableCount) { freeReg--; - _assert(reg == freeReg); + assert reg == freeReg; } } - private void freeExp(ExpDesc e) throws CompileException { + /** + * Free the register used by expression {@code e}. + */ + private void freeExp(ExpDesc e) { if (e.kind == ExpKind.VNONRELOC) freeReg(e.info); } + /** + * Free registers used by both expressions, in the correct order. + */ + private void freeExp(ExpDesc e1, ExpDesc e2) { + int r1 = e1.kind == ExpKind.VNONRELOC ? e1.info : -1; + int r2 = e2.kind == ExpKind.VNONRELOC ? e2.info : -1; + if (r1 > r2) { + freeReg(r1); + freeReg(r2); + } else { + freeReg(r2); + freeReg(r1); + } + } + private int addConstant(LuaValue v) { Integer existing = constantLookup.get(v); if (existing != null) return existing; @@ -340,6 +459,12 @@ private int nilK() { return addConstant(NIL); } + /** + * Fix an expression to return the number of results 'nresults'. + *

+ * Either 'e' is a multi-ret expression (function call or vararg) 'nresults' is LUA_MULTRET (as any expression can + * satisfy that). + */ void setReturns(ExpDesc e, int nresults) throws CompileException { if (e.kind == ExpKind.VCALL) { /* expression is an open function call? */ code[e.info] = SETARG_C(code[e.info], nresults + 1); @@ -347,9 +472,20 @@ void setReturns(ExpDesc e, int nresults) throws CompileException { int op = SETARG_B(code[e.info], nresults + 1); code[e.info] = SETARG_A(op, freeReg); reserveRegs(1); + } else { + assert nresults == LUA_MULTRET; } } + /** + * Fix an expression to return one result. + *

+ * If expression is not a multi-ret expression (function call or vararg), it already returns one result, so nothing + * needs to be done. Function calls become VNONRELOC expressions (as its result comes fixed in the base register of + * the call), while vararg expressions become VRELOCABLE (as OP_VARARG puts its results where it wants). + *

+ * (Calls are created returning one result, so that does not need to be fixed.) + */ void setOneRet(ExpDesc e) { if (e.kind == ExpKind.VCALL) { /* expression is an open function call? */ e.kind = ExpKind.VNONRELOC; @@ -360,6 +496,9 @@ void setOneRet(ExpDesc e) { } } + /** + * Ensure that expression 'e' is not a variable. + */ void dischargeVars(ExpDesc e) throws CompileException { switch (e.kind) { case VLOCAL -> e.kind = ExpKind.VNONRELOC; @@ -367,47 +506,52 @@ void dischargeVars(ExpDesc e) throws CompileException { e.info = codeABCAt(OP_GETUPVAL, 0, e.info, 0, e.position); e.kind = ExpKind.VRELOCABLE; } - case VGLOBAL -> { - e.info = codeABxAt(OP_GETGLOBAL, 0, e.info, e.position); - e.kind = ExpKind.VRELOCABLE; - } case VINDEXED -> { + int op; freeReg(e.aux); - freeReg(e.info); - e.info = codeABCAt(OP_GETTABLE, 0, e.info, e.aux, e.position); + if (e.tableType == ExpKind.VLOCAL) { + freeReg(e.info); + op = OP_GETTABLE; + } else { + assert e.tableType == ExpKind.VUPVAL; + op = OP_GETTABUP; + } + e.info = codeABCAt(op, 0, e.info, e.aux, e.position); e.kind = ExpKind.VRELOCABLE; } case VVARARG, VCALL -> setOneRet(e); default -> { - } /* there is one value available (somewhere) */ + /* there is one value available (somewhere) */ + } } } - private int codeLabel(int A, int b, int jump) throws CompileException { - getLabel(); // those instructions may be jump targets - return codeABC(OP_LOADBOOL, A, b, jump); - } - + /** + * Ensures expression value is in register 'reg' (and therefore 'e' will become a non-relocatable expression). + */ private void discharge2Reg(ExpDesc e, int reg) throws CompileException { dischargeVars(e); switch (e.kind) { case VNIL -> nil(reg, 1); case VFALSE, VTRUE -> codeABC(OP_LOADBOOL, reg, e.kind == ExpKind.VTRUE ? 1 : 0, 0); - case VK -> codeABx(OP_LOADK, reg, e.info); - case VKNUM -> codeABx(OP_LOADK, reg, numberK(e.nval())); + case VK -> codeK(reg, e.info); + case VKNUM -> codeK(reg, numberK(e.nval())); case VRELOCABLE -> code[e.info] = SETARG_A(code[e.info], reg); case VNONRELOC -> { if (reg != e.info) codeABC(OP_MOVE, reg, e.info, 0); } default -> { - _assert(e.kind == ExpKind.VVOID || e.kind == ExpKind.VJMP); - return; /* nothing to do... */ + assert e.kind == ExpKind.VJMP; + return; } } e.info = reg; e.kind = ExpKind.VNONRELOC; } + /** + * Ensures expression is in a register. + */ private void discharge2AnyReg(ExpDesc e) throws CompileException { if (e.kind != ExpKind.VNONRELOC) { reserveRegs(1); @@ -415,27 +559,49 @@ private void discharge2AnyReg(ExpDesc e) throws CompileException { } } + private int codeLoadBool(int A, int b, int jump) throws CompileException { + getLabel(); // those instructions may be jump targets + return codeABC(OP_LOADBOOL, A, b, jump); + } + + private boolean needValue(int list) { + for (; list != NO_JUMP; list = getJump(list)) { + if (GET_OPCODE(code[getJumpControl(list)]) != OP_TESTSET) return true; + } + return false; // not found + } + + /** + * Ensures final expression result (including results from its jump + * lists) is in register 'reg'. + * If expression has jumps, need to patch these jumps either to + * its final position or to "load" instructions (for those tests + * that do not produce values). + */ private void exp2reg(ExpDesc e, int reg) throws CompileException { discharge2Reg(e, reg); if (e.kind == ExpKind.VJMP) concat(e.t, e.info); /* put this jump in `t' list */ if (e.hasjumps()) { - int p_f = NO_JUMP; // position of an eventual LOAD false - int p_t = NO_JUMP; // position of an eventual LOAD true + int pcFalse = NO_JUMP; // position of an eventual LOAD false + int psTrue = NO_JUMP; // position of an eventual LOAD true if (needValue(e.t.value) || needValue(e.f.value)) { int fj = e.kind == ExpKind.VJMP ? NO_JUMP : jump(); - p_f = codeLabel(reg, 0, 1); - p_t = codeLabel(reg, 1, 0); + pcFalse = codeLoadBool(reg, 0, 1); + psTrue = codeLoadBool(reg, 1, 0); patchToHere(fj); } - int _final = getLabel(); // position after whole expression - patchListAux(e.f.value, _final, reg, p_f); - patchListAux(e.t.value, _final, reg, p_t); + int end = getLabel(); // position after whole expression + patchListAux(e.f.value, end, reg, pcFalse); + patchListAux(e.t.value, end, reg, psTrue); } e.f.value = e.t.value = NO_JUMP; e.info = reg; e.kind = ExpKind.VNONRELOC; } + /** + * Ensures final expression result (including results from its jump lists) is in next available register. + */ void exp2NextReg(ExpDesc e) throws CompileException { dischargeVars(e); freeExp(e); @@ -443,6 +609,10 @@ void exp2NextReg(ExpDesc e) throws CompileException { exp2reg(e, freeReg - 1); } + /** + * Ensures final expression result (including results from its jump lists) is in some (any) register and return + * that register. + */ int exp2AnyReg(ExpDesc e) throws CompileException { dischargeVars(e); if (e.kind == ExpKind.VNONRELOC) { @@ -456,6 +626,16 @@ int exp2AnyReg(ExpDesc e) throws CompileException { return e.info; } + /** + * Ensures final expression result is either in a register or in an upvalue. + */ + void exp2AnyRegUp(ExpDesc e) throws CompileException { + if (e.kind != ExpKind.VUPVAL || e.hasjumps()) exp2AnyReg(e); + } + + /** + * Ensures final expression result is either in a register or it is a constant. + */ void exp2Val(ExpDesc e) throws CompileException { if (e.hasjumps()) { exp2AnyReg(e); @@ -464,27 +644,25 @@ void exp2Val(ExpDesc e) throws CompileException { } } + /** + * Ensures final expression result is in a valid R/K index (that is, it is either in a register or in 'k' with an + * index in the range of R/K indices). Returns R/K index. + */ int exp2RK(ExpDesc e) throws CompileException { exp2Val(e); + + // Promote constants to VK. switch (e.kind) { - case VKNUM, VTRUE, VFALSE, VNIL -> { - if (constants.size() <= MAXINDEXRK) { /* constant fit in RK operand? */ - e.info = e.kind == ExpKind.VNIL ? nilK() - : e.kind == ExpKind.VKNUM ? numberK(e.nval()) - : boolK(e.kind == ExpKind.VTRUE); - e.kind = ExpKind.VK; - return RKASK(e.info); - } - } - case VK -> { - if (e.info <= MAXINDEXRK) /* constant fit in argC? */ { - return RKASK(e.info); - } - } - default -> { - } + case VNIL -> e.setConstant(nilK()); + case VTRUE -> e.setConstant(boolK(true)); + case VFALSE -> e.setConstant(boolK(false)); + case VKNUM -> e.setConstant(numberK(e.nval())); } - /* not a constant in the right range: put it in a register */ + + // Then use if they fit within the given range. + if (e.kind == ExpKind.VK && e.info <= MAXINDEXRK) return RKASK(e.info); + + // Otherwise doesn't fit in the range - promote to a register. return exp2AnyReg(e); } @@ -499,41 +677,41 @@ void storeVar(ExpDesc var, ExpDesc ex) throws CompileException { int e = exp2AnyReg(ex); codeABCAt(OP_SETUPVAL, e, var.info, 0, var.position); } - case VGLOBAL -> { - int e = exp2AnyReg(ex); - codeABxAt(OP_SETGLOBAL, e, var.info, var.position); - } case VINDEXED -> { + int opcode = var.tableType == ExpKind.VLOCAL ? OP_SETTABLE : OP_SETTABUP; int e = exp2RK(ex); - codeABCAt(OP_SETTABLE, var.info, var.aux, e, var.position); + codeABCAt(opcode, var.info, var.aux, e, var.position); } - default -> _assert(false); /* invalid var kind to store */ + default -> throw new AssertionError("invalid var kind to store"); } freeExp(ex); } void self(ExpDesc e, ExpDesc key) throws CompileException { - int func; exp2AnyReg(e); + int eReg = e.info; freeExp(e); - func = freeReg; + e.info = freeReg; + e.kind = ExpKind.VNONRELOC; reserveRegs(2); - codeABC(OP_SELF, func, e.info, exp2RK(key)); + codeABC(OP_SELF, e.info, eReg, exp2RK(key)); freeExp(key); - e.info = func; - e.kind = ExpKind.VNONRELOC; } - private void invertJump(ExpDesc e) throws CompileException { + private void negateCondition(ExpDesc e) { int pc = getJumpControl(e.info); - int op = code[pc]; - _assert(testTMode(GET_OPCODE(op)) && GET_OPCODE(op) != OP_TESTSET && GET_OPCODE(op) != OP_TEST); + int i = code[pc]; + assert testTMode(GET_OPCODE(i)) && GET_OPCODE(i) != OP_TESTSET && GET_OPCODE(i) != OP_TEST; - int a = GETARG_A(op); - int nota = a != 0 ? 0 : 1; - code[pc] = SETARG_A(op, nota); + code[pc] = SETARG_A(i, GETARG_A(i) != 0 ? 0 : 1); } + /** + * Emit instruction to jump if 'e' is 'cond' (that is, if 'cond' + * is true, code will jump if 'e' is true.) Return jump position. + * Optimize when 'e' is 'not' something, inverting the condition + * and removing the 'not'. + */ private int jumpOnCond(ExpDesc e, int cond) throws CompileException { if (e.kind == ExpKind.VRELOCABLE) { int ie = code[e.info]; @@ -548,32 +726,34 @@ private int jumpOnCond(ExpDesc e, int cond) throws CompileException { return condJump(OP_TESTSET, NO_REG, e.info, cond, lexer.lastPosition()); } + /** + * Emit code to go through if 'e' is true, jump otherwise. + */ void goIfTrue(ExpDesc e) throws CompileException { dischargeVars(e); - int pc; /* pc of last jump */ - switch (e.kind) { - case VK, VKNUM, VTRUE -> pc = NO_JUMP; /* always true; do nothing */ - case VFALSE -> pc = jump(); /* always jump */ + int pc = switch (e.kind) { /* pc of last jump */ + case VK, VKNUM, VTRUE -> NO_JUMP; /* always true; do nothing */ case VJMP -> { - invertJump(e); - pc = e.info; + negateCondition(e); + yield e.info; } - default -> pc = jumpOnCond(e, 0); - } + default -> jumpOnCond(e, 0); + }; concat(e.f, pc); /* insert last jump in `f' list */ patchToHere(e.t.value); e.t.value = NO_JUMP; } - private void goIfFalse(ExpDesc e) throws CompileException { + /** + * Emit code to go through if 'e' is false, jump otherwise. + */ + void goIfFalse(ExpDesc e) throws CompileException { dischargeVars(e); - int pc; /* pc of last jump */ - switch (e.kind) { - case VNIL, VFALSE -> pc = NO_JUMP; /* always false; do nothing */ - case VTRUE -> pc = jump(); /* always jump */ - case VJMP -> pc = e.info; - default -> pc = jumpOnCond(e, 1); - } + int pc = switch (e.kind) { /* pc of last jump */ + case VNIL, VFALSE -> NO_JUMP; /* always false; do nothing */ + case VJMP -> e.info; + default -> jumpOnCond(e, 1); + }; concat(e.t, pc); /* insert last jump in `t' list */ patchToHere(e.f.value); e.f.value = NO_JUMP; @@ -584,14 +764,14 @@ private void codeNot(ExpDesc e) throws CompileException { switch (e.kind) { case VNIL, VFALSE -> e.kind = ExpKind.VTRUE; case VK, VKNUM, VTRUE -> e.kind = ExpKind.VFALSE; - case VJMP -> invertJump(e); + case VJMP -> negateCondition(e); case VRELOCABLE, VNONRELOC -> { discharge2AnyReg(e); freeExp(e); e.info = codeABC(OP_NOT, 0, e.info, 0); e.kind = ExpKind.VRELOCABLE; } - default -> _assert(false); /* cannot happen */ + default -> throw new AssertionError(); } /* interchange true and false lists */ { @@ -604,12 +784,14 @@ private void codeNot(ExpDesc e) throws CompileException { } void indexed(ExpDesc t, ExpDesc k, long pos) throws CompileException { + assert !t.hasjumps() && (t.kind.isInRegister() || t.kind == ExpKind.VUPVAL); t.aux = exp2RK(k); + t.tableType = t.kind == ExpKind.VUPVAL ? ExpKind.VUPVAL : ExpKind.VLOCAL; t.kind = ExpKind.VINDEXED; t.position = pos; } - private boolean constFolding(int op, ExpDesc e1, ExpDesc e2) throws CompileException { + private static boolean constFolding(int op, ExpDesc e1, ExpDesc e2) { if (!e1.isnumeral() || !e2.isnumeral()) return false; double v1 = e1.nval().toDouble(); @@ -632,10 +814,7 @@ private boolean constFolding(int op, ExpDesc e1, ExpDesc e2) throws CompileExcep case OP_LEN -> { return false; /* no constant folding for 'len' */ } - default -> { - _assert(false); - return false; - } + default -> throw new AssertionError(); } if (Double.isNaN(r)) return false; /* do not attempt to produce NaN */ @@ -643,52 +822,65 @@ private boolean constFolding(int op, ExpDesc e1, ExpDesc e2) throws CompileExcep return true; } - private void codeArith(int op, ExpDesc e1, ExpDesc e2, long position) throws CompileException { - if (constFolding(op, e1, e2)) return; + /** + * Emit code for unary expressions that "produce values" (everything but 'not'). + *

+ * Expression to produce final result will be encoded in 'e'. + */ + private void codeUnaryExpression(int opcode, ExpDesc e, long position) throws CompileException { + int r = exp2AnyReg(e); + freeExp(e); + e.info = codeABCAt(opcode, 0, r, 0, position); + e.kind = ExpKind.VRELOCABLE; + } - int o2 = op != OP_UNM && op != OP_LEN ? exp2RK(e2) : 0; - int o1 = exp2RK(e1); - if (o1 > o2) { - freeExp(e1); - freeExp(e2); - } else { - freeExp(e2); - freeExp(e1); - } - e1.info = codeABCAt(op, 0, o1, o2, position); + /** + * Emit code for binary expressions that "produce values" (everything but logical operators 'and'/'or' and + * comparison operators). + *

+ * Expression to produce final result will be encoded in 'e1'. Because 'luaK_exp2RK' can free registers, its calls + * must be in "stack order" (that is, first on 'e2', which may have more recent registers to be released). + */ + private void codeBinaryExpression(int opcode, ExpDesc e1, ExpDesc e2, long position) throws CompileException { + int rk2 = exp2RK(e2); + int rk1 = exp2RK(e1); + freeExp(e1, e2); + e1.info = codeABCAt(opcode, 0, rk1, rk2, position); e1.kind = ExpKind.VRELOCABLE; } - private void codeComparison(int op, int cond, ExpDesc e1, ExpDesc e2, long position) throws CompileException { - int o1 = exp2RK(e1); - int o2 = exp2RK(e2); - freeExp(e2); - freeExp(e1); - if (cond == 0 && op != OP_EQ) { - int temp; /* exchange args to replace by `<' or `<=' */ - temp = o1; - o1 = o2; - o2 = temp; /* o1 <==> o2 */ - cond = 1; - } - e1.info = condJump(op, cond, o1, o2, position); + private void codeArith(int op, ExpDesc e1, ExpDesc e2, long position) throws CompileException { + if (!constFolding(op, e1, e2)) codeBinaryExpression(op, e1, e2, position); + } + + private void codeComparison(BinOpr op, ExpDesc e1, ExpDesc e2, long position) throws CompileException { + int rk1 = switch (e1.kind) { + case VK -> RKASK(e1.info); + case VNONRELOC -> e1.info; + default -> throw new AssertionError(); + }; + int rk2 = exp2RK(e2); + freeExp(e1, e2); + + e1.info = switch (op) { + case NE -> condJump(OP_EQ, 0, rk1, rk2, position); // a ~= b => not (a == b) + case GT -> condJump(OP_LT, 1, rk2, rk1, position); // a > b => b < a + case GE -> condJump(OP_LE, 1, rk2, rk1, position); // a >= b => b <= a + case EQ -> condJump(OP_EQ, 1, rk1, rk2, position); + case LT -> condJump(OP_LT, 1, rk1, rk2, position); + case LE -> condJump(OP_LE, 1, rk1, rk2, position); + default -> throw new AssertionError("not a comparison"); + }; e1.kind = ExpKind.VJMP; } void prefix(UnOpr op, ExpDesc e, long position) throws CompileException { - ExpDesc e2 = new ExpDesc(); - e2.init(ExpKind.VKNUM, 0); switch (op) { case MINUS -> { - if (e.kind == ExpKind.VK) exp2AnyReg(e); /* cannot operate on non-numeric constants */ - codeArith(OP_UNM, e, e2, position); + if (!constFolding(OP_UNM, e, e)) codeUnaryExpression(OP_UNM, e, position); } case NOT -> codeNot(e); - case LEN -> { - exp2AnyReg(e); /* cannot operate on constants */ - codeArith(OP_LEN, e, e2, position); - } - default -> _assert(false); + case LEN -> codeUnaryExpression(OP_LEN, e, position); } } @@ -707,13 +899,13 @@ void infix(BinOpr op, ExpDesc v) throws CompileException { void posfix(BinOpr op, ExpDesc e1, ExpDesc e2, long position) throws CompileException { switch (op) { case AND -> { - _assert(e1.t.value == NO_JUMP); /* list must be closed */ + assert e1.t.value == NO_JUMP; /* list must be closed */ dischargeVars(e2); concat(e2.f, e1.f.value); e1.setValue(e2); } case OR -> { - _assert(e1.f.value == NO_JUMP); /* list must be closed */ + assert e1.f.value == NO_JUMP; /* list must be closed */ dischargeVars(e2); concat(e2.t, e1.t.value); e1.setValue(e2); @@ -721,14 +913,14 @@ void posfix(BinOpr op, ExpDesc e1, ExpDesc e2, long position) throws CompileExce case CONCAT -> { exp2Val(e2); if (e2.kind == ExpKind.VRELOCABLE && GET_OPCODE(code[e2.info]) == OP_CONCAT) { - _assert(e1.info == GETARG_B(code[e2.info]) - 1); + assert e1.info == GETARG_B(code[e2.info]) - 1; freeExp(e1); code[e2.info] = SETARG_B(code[e2.info], e1.info); e1.kind = ExpKind.VRELOCABLE; e1.info = e2.info; } else { exp2NextReg(e2); /* operand must be on the 'stack' */ - codeArith(OP_CONCAT, e1, e2, position); + codeBinaryExpression(OP_CONCAT, e1, e2, position); } } case ADD -> codeArith(OP_ADD, e1, e2, position); @@ -737,13 +929,8 @@ void posfix(BinOpr op, ExpDesc e1, ExpDesc e2, long position) throws CompileExce case DIV -> codeArith(OP_DIV, e1, e2, position); case MOD -> codeArith(OP_MOD, e1, e2, position); case POW -> codeArith(OP_POW, e1, e2, position); - case EQ -> codeComparison(OP_EQ, 1, e1, e2, position); - case NE -> codeComparison(OP_EQ, 0, e1, e2, position); - case LT -> codeComparison(OP_LT, 1, e1, e2, position); - case LE -> codeComparison(OP_LE, 1, e1, e2, position); - case GT -> codeComparison(OP_LT, 0, e1, e2, position); - case GE -> codeComparison(OP_LE, 0, e1, e2, position); - default -> _assert(false); + case EQ, NE, LT, LE, GT, GE -> codeComparison(op, e1, e2, position); + default -> throw new AssertionError("Unknown op " + op); } } @@ -752,59 +939,16 @@ void fixPosition(long position) { columnInfo[pc - 1] = Lex.unpackColumn(position); } - private int code(int instruction, long position) throws CompileException { - dischargeJumpPc(); /* `pc' will change */ - - // put new instruction in code array - if (code == null || pc + 1 > code.length) code = LuaC.realloc(code, pc * 2 + 1); - code[pc] = instruction; - - // save corresponding line information - if (lineInfo == null || pc + 1 > lineInfo.length) { - lineInfo = LuaC.realloc(lineInfo, pc * 2 + 1); - columnInfo = LuaC.realloc(columnInfo, pc * 2 + 1); - } - lineInfo[pc] = Lex.unpackLine(position); - columnInfo[pc] = Lex.unpackColumn(position); - - return pc++; - } - - int codeABCAt(int o, int a, int b, int c, long position) throws CompileException { - _assert(getOpMode(o) == iABC); - _assert(getBMode(o) != OpArgN || b == 0); - _assert(getCMode(o) != OpArgN || c == 0); - _assert(position > 0); - return code(CREATE_ABC(o, a, b, c), position); - } - - int codeABC(int o, int a, int b, int c) throws CompileException { - _assert(getOpMode(o) == iABC); - _assert(getBMode(o) != OpArgN || b == 0); - _assert(getCMode(o) != OpArgN || c == 0); - return code(CREATE_ABC(o, a, b, c), lexer.lastPosition()); - } - - int codeABxAt(int o, int a, int bc, long position) throws CompileException { - _assert(getOpMode(o) == iABx || getOpMode(o) == iAsBx); - _assert(getCMode(o) == OpArgN); - _assert(position > 0); - return code(CREATE_ABx(o, a, bc), position); - } - - int codeABx(int o, int a, int bc) throws CompileException { - return codeABxAt(o, a, bc, lexer.lastPosition()); - } void setList(int base, int nelems, int tostore) throws CompileException { int c = (nelems - 1) / LFIELDS_PER_FLUSH + 1; int b = tostore == LUA_MULTRET ? 0 : tostore; - _assert(tostore != 0); + assert tostore != 0; if (c <= MAXARG_C) { codeABC(OP_SETLIST, base, b, c); } else { codeABC(OP_SETLIST, base, b, 0); - code(c, lexer.lastPosition()); + code(CREATE_Ax(OP_EXTRAARG, c), lexer.lastPosition()); } freeReg = base + 1; /* free registers with list values */ } diff --git a/src/main/java/org/squiddev/cobalt/compiler/Lex.java b/src/main/java/org/squiddev/cobalt/compiler/Lex.java index 9510dcdb..cc118f42 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/Lex.java +++ b/src/main/java/org/squiddev/cobalt/compiler/Lex.java @@ -27,22 +27,22 @@ final class Lex { // Terminal symbols denoted by reserved words static final int TK_AND = 257, TK_BREAK = 258, TK_DO = 259, TK_ELSE = 260, TK_ELSEIF = 261, - TK_END = 262, TK_FALSE = 263, TK_FOR = 264, TK_FUNCTION = 265, TK_IF = 266, - TK_IN = 267, TK_LOCAL = 268, TK_NIL = 269, TK_NOT = 270, TK_OR = 271, TK_REPEAT = 272, - TK_RETURN = 273, TK_THEN = 274, TK_TRUE = 275, TK_UNTIL = 276, TK_WHILE = 277; + TK_END = 262, TK_FALSE = 263, TK_FOR = 264, TK_FUNCTION = 265, TK_GOTO = 266, TK_IF = 267, + TK_IN = 268, TK_LOCAL = 269, TK_NIL = 270, TK_NOT = 271, TK_OR = 272, TK_REPEAT = 273, + TK_RETURN = 274, TK_THEN = 275, TK_TRUE = 276, TK_UNTIL = 277, TK_WHILE = 278; // Other terminal symbols static final int - TK_CONCAT = 278, TK_DOTS = 279, TK_EQ = 280, TK_GE = 281, TK_LE = 282, TK_NE = 283, - TK_EOS = 284, - TK_NUMBER = 285, TK_NAME = 286, TK_STRING = 287; + TK_CONCAT = 279, TK_DOTS = 280, TK_EQ = 281, TK_GE = 282, TK_LE = 283, TK_NE = 284, TK_DBCOLON = 285, + TK_EOS = 286, + TK_NUMBER = 287, TK_NAME = 288, TK_STRING = 289; /* Token names: this must be consistent with the list above. */ private final static String[] tokenNames = { "and", "break", "do", "else", "elseif", - "end", "false", "for", "function", "if", + "end", "false", "for", "function", "goto", "if", "in", "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while", - "..", "...", "==", ">=", "<=", "~=", + "..", "...", "==", ">=", "<=", "~=", "::", "", "", "", "", }; @@ -55,11 +55,18 @@ final class Lex { static { Map reserved = new HashMap<>(); for (int i = 0; i < NUM_RESERVED; i++) { + // We skip GOTO and inject it later on when parsing statements. + if (FIRST_RESERVED + i == TK_GOTO) continue; + reserved.put(ValueFactory.valueOf(tokenNames[i]).toBuffer(), FIRST_RESERVED + i); } RESERVED = Collections.unmodifiableMap(reserved); } + static boolean isReserved(LuaString name) { + return RESERVED.containsKey(name.toBuffer()); + } + static class Token { private int token; private LuaValue value; @@ -532,6 +539,10 @@ private int lexToken(Token token) throws CompileException, LuaError, UnwindThrow next(); return checkNext('=') ? TK_NE : '~'; } + case ':' -> { + next(); + return checkNext(':') ? TK_DBCOLON : ':'; + } case '"', '\'' -> { token.value = readString(current); return TK_STRING; diff --git a/src/main/java/org/squiddev/cobalt/compiler/LoadState.java b/src/main/java/org/squiddev/cobalt/compiler/LoadState.java index 1f1f30c3..ff689cb5 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/LoadState.java +++ b/src/main/java/org/squiddev/cobalt/compiler/LoadState.java @@ -76,7 +76,7 @@ public interface FunctionFactory { * @param env The function's environment. * @return The loaded function */ - LuaClosure load(Prototype prototype, LuaTable env); + LuaClosure load(Prototype prototype, LuaValue env); } private LoadState() { @@ -85,13 +85,14 @@ private LoadState() { /** * A basic {@link FunctionFactory} which loads into */ - public static LuaClosure interpretedFunction(Prototype prototype, LuaTable env) { - LuaInterpretedFunction closure = new LuaInterpretedFunction(prototype, env); + public static LuaClosure interpretedFunction(Prototype prototype, LuaValue env) { + LuaInterpretedFunction closure = new LuaInterpretedFunction(prototype); closure.nilUpvalues(); + if (closure.upvalues.length > 0) closure.upvalues[0].setValue(env); return closure; } - public static LuaClosure load(LuaState state, InputStream stream, String name, LuaTable env) throws CompileException, LuaError { + public static LuaClosure load(LuaState state, InputStream stream, String name, LuaValue env) throws CompileException, LuaError { return load(state, stream, valueOf(name), env); } @@ -106,11 +107,11 @@ public static LuaClosure load(LuaState state, InputStream stream, String name, L * @throws IllegalArgumentException If the signature is bac * @throws CompileException If the stream cannot be loaded. */ - public static LuaClosure load(LuaState state, InputStream stream, LuaString name, LuaTable env) throws CompileException, LuaError { + public static LuaClosure load(LuaState state, InputStream stream, LuaString name, LuaValue env) throws CompileException, LuaError { return load(state, stream, name, null, env); } - public static LuaClosure load(LuaState state, InputStream stream, LuaString name, LuaString mode, LuaTable env) throws CompileException, LuaError { + public static LuaClosure load(LuaState state, InputStream stream, LuaString name, LuaString mode, LuaValue env) throws CompileException, LuaError { return state.compiler.load(LuaC.compile(state, stream, name, mode), env); } diff --git a/src/main/java/org/squiddev/cobalt/compiler/LuaBytecodeFormat.java b/src/main/java/org/squiddev/cobalt/compiler/LuaBytecodeFormat.java index 4f92de0a..edfd0756 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/LuaBytecodeFormat.java +++ b/src/main/java/org/squiddev/cobalt/compiler/LuaBytecodeFormat.java @@ -15,15 +15,19 @@ public final class LuaBytecodeFormat implements BytecodeFormat { static final byte[] LUA_SIGNATURE = {27, 'L', 'u', 'a'}; /** - * The current Lua bytecode format, currently Lua 5.1 + * The current Lua bytecode format, currently Lua 5.2 */ - static final int LUAC_VERSION = 0x51; + static final int LUAC_VERSION = 0x52; /** * The format for binary files. 0 denotes the "official" format. */ static final int LUAC_FORMAT = 0; + /** + * Tail after the header, to catch conversion errors. + */ + static final byte[] LUAC_TAIL = {0x19, (byte) 0x93, '\r', '\n', 0x1a, '\n'}; private static final LuaBytecodeFormat INSTANCE = new LuaBytecodeFormat(); @@ -47,7 +51,7 @@ public SuspendedFunction readFunction(LuaString name, InputReader inp try { loader.checkSignature(); loader.loadHeader(); - return loader.loadFunction(LoadState.getSourceName(name)); + return loader.loadFunction(); } catch (CompileException e) { throw throwUnchecked0(e); } diff --git a/src/main/java/org/squiddev/cobalt/compiler/LuaC.java b/src/main/java/org/squiddev/cobalt/compiler/LuaC.java index ac0985fd..58123f5b 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/LuaC.java +++ b/src/main/java/org/squiddev/cobalt/compiler/LuaC.java @@ -116,6 +116,11 @@ public static int CREATE_ABx(int o, int a, int bc) { ((bc << Lua.POS_Bx) & Lua.MASK_Bx); } + public static int CREATE_Ax(int o, int a) { + return ((o << Lua.POS_OP) & Lua.MASK_OP) | + ((a << Lua.POS_Ax) & Lua.MASK_Ax); + } + public static int[] realloc(int[] v, int n) { int[] a = new int[n]; if (v != null) System.arraycopy(v, 0, a, 0, Math.min(v.length, n)); @@ -128,6 +133,12 @@ public static byte[] realloc(byte[] v, int n) { return a; } + public static short[] realloc(short[] v, int n) { + short[] a = new short[n]; + if (v != null) System.arraycopy(v, 0, a, 0, Math.min(v.length, n)); + return a; + } + private LuaC() { } @@ -179,16 +190,7 @@ public static Prototype compile(LuaState state, InputReader stream, LuaString na private static Prototype loadTextChunk(int firstByte, InputReader stream, LuaString name) throws CompileException, LuaError, UnwindThrowable { Parser parser = new Parser(stream, firstByte, name, LoadState.getShortName(name)); parser.lexer.skipShebang(); - FuncState funcstate = parser.openFunc(); - funcstate.varargFlags = Lua.VARARG_ISVARARG; /* main func. is always vararg */ - - parser.lexer.nextToken(); // read first token - parser.chunk(); - parser.check(Lex.TK_EOS); - Prototype prototype = parser.closeFunc(); - LuaC._assert(funcstate.upvalues.size() == 0); - LuaC._assert(parser.fs == null); - return prototype; + return parser.mainFunction(); } public record InputStreamReader(InputStream stream) implements InputReader { diff --git a/src/main/java/org/squiddev/cobalt/compiler/Parser.java b/src/main/java/org/squiddev/cobalt/compiler/Parser.java index 2ec1bbde..f4d1299d 100644 --- a/src/main/java/org/squiddev/cobalt/compiler/Parser.java +++ b/src/main/java/org/squiddev/cobalt/compiler/Parser.java @@ -24,10 +24,14 @@ */ package org.squiddev.cobalt.compiler; +import org.checkerframework.checker.nullness.qual.Nullable; import org.squiddev.cobalt.*; import org.squiddev.cobalt.function.LocalVariable; import org.squiddev.cobalt.unwind.AutoUnwind; +import java.util.ArrayList; +import java.util.List; + import static org.squiddev.cobalt.Lua.*; import static org.squiddev.cobalt.compiler.Lex.*; import static org.squiddev.cobalt.compiler.LuaC.LUAI_MAXUPVALUES; @@ -53,11 +57,29 @@ private static String LUA_QL(String s) { static final int NO_JUMP = -1; final Lex lexer; + final LuaString envName, gotoName; FuncState fs; public int nCcalls; + /** + * A stack of active labels in scope. The labels available in the current block are all those at or above index + * {@link FuncState.BlockCnt#firstLabel}. + */ + private final List activeLabels = new ArrayList<>(); + + /** + * A stack of unsolved gotos in scope. The gotos defined in the current block are all those at or above index + * {@link FuncState.BlockCnt#firstLabel}. + */ + private final List pendingGotos = new ArrayList<>(); + + private int activeVariableSize; + private short[] activeVariables = new short[16]; + public Parser(InputReader stream, int firstByte, LuaString source, LuaString shortSource) { lexer = new Lex(source, shortSource, stream, firstByte); + envName = lexer.newString("_ENV"); + gotoName = lexer.newString("goto"); fs = null; } @@ -72,6 +94,7 @@ static class ExpDesc { private LuaNumber nval; int info; int aux; + @Nullable ExpKind tableType; final IntPtr t = new IntPtr(); /* patch list of `exit when true' */ final IntPtr f = new IntPtr(); /* patch list of `exit when false' */ @@ -107,6 +130,34 @@ public void setValue(ExpDesc other) { t.value = other.t.value; f.value = other.f.value; } + + public void setConstant(int constantIndex) { + kind = ExpKind.VK; + info = constantIndex; + } + } + + /** + * A description for a label. + */ + static final class LabelDesc { + private final LuaString name; + private final int pc; + private final int line; + private short activeVariables; + + /** + * @param name Label identifier + * @param pc Position in code + * @param line Line where it appeared + * @param activeVariables Local level where it appears in current block + */ + LabelDesc(LuaString name, int pc, int line, short activeVariables) { + this.name = name; + this.pc = pc; + this.line = line; + this.activeVariables = activeVariables; + } } /*---------------------------------------------------------------------- @@ -118,6 +169,10 @@ CompileException syntaxError(String msg) { return lexer.syntaxError(msg); } + CompileException semError(String msg) { + return lexer.lexError(msg, 0); + } + private void errorExpected(int token) throws CompileException { throw syntaxError(token2str(token) + " expected"); } @@ -187,67 +242,80 @@ private int registerLocal(LuaString name) { return index; } - private void newLocal(String name, int n) throws CompileException { - newLocal(lexer.newString(name), n); + private void newLocal(String name) throws CompileException { + newLocal(lexer.newString(name)); } - private void newLocal(LuaString name, int n) throws CompileException { + private void newLocal(LuaString name) throws CompileException { FuncState fs = this.fs; - checkLimit(fs, fs.activeVariableCount + n + 1, LuaC.LUAI_MAXVARS, "local variables"); - fs.activeVariables[fs.activeVariableCount + n] = (short) registerLocal(name); + checkLimit(fs, activeVariableSize + 1, LuaC.LUAI_MAXVARS, "local variables"); + if (activeVariableSize >= activeVariables.length) { + activeVariables = LuaC.realloc(activeVariables, activeVariables.length * 2 + 1); + } + activeVariables[activeVariableSize++] = (short) registerLocal(name); + } + + LocalVariable getLocal(FuncState fs, int i) { + return fs.locals.get(activeVariables[fs.firstLocal + i]); } private void adjustLocalVars(int nVars) { FuncState fs = this.fs; fs.activeVariableCount = (short) (fs.activeVariableCount + nVars); for (; nVars > 0; nVars--) { - fs.getLocal(fs.activeVariableCount - nVars).startpc = fs.pc; + getLocal(fs, fs.activeVariableCount - nVars).startpc = fs.pc; } } void removeVars(int toLevel) { FuncState fs = this.fs; + activeVariableSize -= fs.activeVariableCount - toLevel; while (fs.activeVariableCount > toLevel) { - fs.getLocal(--fs.activeVariableCount).endpc = fs.pc; + getLocal(fs, --fs.activeVariableCount).endpc = fs.pc; } } - private int indexUpvalue(FuncState fs, LuaString name, ExpDesc v) throws CompileException { + private int searchUpvalue(FuncState fs, LuaString name) { // Search for an existing upvalue for (int i = 0; i < fs.upvalues.size(); i++) { - FuncState.UpvalueDesc upvalue = fs.upvalues.get(i); - if (upvalue.kind == v.kind && upvalue.info == v.info) { - assert upvalue.name == name; - return i; - } + if (fs.upvalues.get(i).name() == name) return i; } + return -1; + } + + private int newUpvalue(FuncState fs, LuaString name, ExpDesc v) throws CompileException { // Add a new upvalue checkLimit(fs, fs.upvalues.size(), LUAI_MAXUPVALUES, "upvalues"); assert v.kind == ExpKind.VLOCAL || v.kind == ExpKind.VUPVAL; int index = fs.upvalues.size(); - fs.upvalues.add(new FuncState.UpvalueDesc(name, v.kind, (short) v.info)); + fs.upvalues.add(new Prototype.UpvalueInfo(name, v.kind == ExpKind.VLOCAL, (byte) v.info)); return index; } private int searchVar(FuncState fs, LuaString n) { for (int i = fs.activeVariableCount - 1; i >= 0; i--) { - if (n == fs.getLocal(i).name) return i; + if (n == getLocal(fs, i).name) return i; } - return -1; /* not found */ + return -1; } + /** + * Mark block where variable at given level was defined (to emit close instructions later). + */ private void markUpvalue(FuncState fs, int level) { FuncState.BlockCnt block = fs.block; - while (block != null && block.nactvar > level) block = block.previous; + while (block != null && block.activeVariableCount > level) block = block.previous; if (block != null) block.upval = true; } + /** + * Find variable with given name 'n'. If it is an upvalue, add this upvalue into all intermediate functions. + */ private ExpKind singleVarAux(FuncState fs, LuaString n, ExpDesc var, boolean base) throws CompileException { if (fs == null) { // No more levels - var.init(ExpKind.VGLOBAL, NO_REG); - return ExpKind.VGLOBAL; + return ExpKind.VVOID; // Default is global } int v = searchVar(fs, n); // look up at current level @@ -256,11 +324,15 @@ private ExpKind singleVarAux(FuncState fs, LuaString n, ExpDesc var, boolean bas if (!base) markUpvalue(fs, v); // local will be used as an upvalue return ExpKind.VLOCAL; } else { - // not found at current level; try upper one - if (singleVarAux(fs.prev, n, var, false) == ExpKind.VGLOBAL) return ExpKind.VGLOBAL; + // Not found at current level, try an upvalue + int index = searchUpvalue(fs, n); + if (index < 0) { + // No such upvalue, search in the parent. + if (singleVarAux(fs.prev, n, var, false) == ExpKind.VVOID) return ExpKind.VVOID; + index = newUpvalue(fs, n, var); // Add the new upvalue. + } - var.info = indexUpvalue(fs, n, var); // else was LOCAL or UPVAL - var.kind = ExpKind.VUPVAL; // upvalue in this level + var.init(ExpKind.VUPVAL, index); return ExpKind.VUPVAL; } } @@ -268,10 +340,14 @@ private ExpKind singleVarAux(FuncState fs, LuaString n, ExpDesc var, boolean bas private void singleVar(ExpDesc var) throws CompileException, LuaError, UnwindThrowable { var.position = lexer.token.position(); - LuaString varname = strCheckName(); + LuaString varName = strCheckName(); FuncState fs = this.fs; - if (singleVarAux(fs, varname, var, true) == ExpKind.VGLOBAL) { - var.info = fs.stringK(varname); // info points to global name + if (singleVarAux(fs, varName, var, true) == ExpKind.VVOID) { // Looking up a global + var key = new ExpDesc(); + var global = singleVarAux(fs, envName, var, true); // var := _ENV + assert global == ExpKind.VLOCAL || global == ExpKind.VUPVAL; + codeString(key, varName); // key := $varName + fs.indexed(var, key, var.position); // var := var[key] } } @@ -291,63 +367,178 @@ private void adjustAssign(int nvars, int nexps, ExpDesc e) throws CompileExcepti fs.nil(reg, extra); } } + + if (nexps > nvars) fs.freeReg -= nexps - nvars; } private void enterLevel() throws CompileException { - if (++nCcalls > LUAI_MAXCCALLS) { - throw lexer.lexError("chunk has too many syntax levels", 0); - } + nCcalls++; + checkLimit(fs, nCcalls, LUAI_MAXCCALLS, "syntax levels"); } private void leaveLevel() { nCcalls--; } - private void enterBlock(FuncState fs, FuncState.BlockCnt bl, boolean isbreakable) throws CompileException { - bl.breaklist.value = Parser.NO_JUMP; - bl.isbreakable = isbreakable; - bl.nactvar = fs.activeVariableCount; + /** + * Solves the goto at index 'g' to given 'label' and removes it from the list of pending gotos. + *

+ * If it jumps into the scope of some variable, raises an error. + */ + private void closeGoto(int g, LabelDesc label) throws CompileException { + var gt = pendingGotos.get(g); + assert gt.name == label.name; + + if (gt.activeVariables < label.activeVariables) { + var local = getLocal(fs, gt.activeVariables).name; + throw semError(" at line " + gt.line + " jumps into the scope of local '" + local + "'"); + } + + fs.patchList(gt.pc, label.pc); + pendingGotos.remove(g); + } + + /** + * Try to close a goto with existing labels; this solves backward jumps + * + * @param g The goto index. + * @return If a label was found. + */ + private boolean findLabel(int g) throws CompileException { + var block = fs.block; + var gt = pendingGotos.get(g); + for (int i = block.firstLabel; i < activeLabels.size(); i++) { + var label = activeLabels.get(i); + if (label.name != gt.name) continue; + + assert block.upval || activeLabels.size() > block.firstLabel; + if (gt.activeVariables > label.activeVariables) fs.patchClose(gt.pc, label.activeVariables); + closeGoto(g, label); + return true; + } + + return false; + } + + private int newLabelEntry(List labels, LuaString label, int line, int pc) throws CompileException { + checkLimit(fs, labels.size() + 1, Short.MAX_VALUE, "labels/gotos"); + int index = labels.size(); + labels.add(new LabelDesc(label, pc, line, fs.activeVariableCount)); + return index; + } + + /** + * Add a label with the given name + * + * @param label The label whose corresponding gotos should be "closed". + */ + private void findGotos(LabelDesc label) throws CompileException { + int i = fs.block.firstGoto; + + while (i < pendingGotos.size()) { + if (pendingGotos.get(i).name == label.name) { + closeGoto(i, label); + } else { + i++; + } + } + } + + /** + * Export pending gotos to outer level, to check them against outer labels; if the block being exited has upvalues, + * and the goto exits the scope of any variable (which can be the upvalue), close those variables being exited. + * + * @param bl The current block. + */ + private void moveGotosOut(FuncState.BlockCnt bl) throws CompileException { + for (int i = bl.firstGoto; i < pendingGotos.size(); ) { + var gt = pendingGotos.get(i); + if (gt.activeVariables > bl.activeVariableCount) { + if (bl.upval) fs.patchClose(gt.pc, bl.activeVariableCount); + gt.activeVariables = bl.activeVariableCount; + } + if (!findLabel(i)) i++; + } + } + + private void enterBlock(FuncState fs, boolean isLoop) { + enterBlock(fs, new FuncState.BlockCnt(), isLoop); + } + + private void enterBlock(FuncState fs, FuncState.BlockCnt bl, boolean isLoop) { + bl.isLoop = isLoop; + bl.activeVariableCount = fs.activeVariableCount; + bl.firstLabel = (short) activeLabels.size(); + bl.firstGoto = (short) pendingGotos.size(); bl.upval = false; bl.previous = fs.block; fs.block = bl; assert fs.freeReg == fs.activeVariableCount; } + /** + * Create a label named 'break' to resolve break statements + */ + private void breakLabel() throws CompileException { + var label = newLabelEntry(activeLabels, lexer.newString("break"), 0, fs.pc); + findGotos(activeLabels.get(label)); + } + + private static CompileException undefGoto(LabelDesc gt) { + return new CompileException(Lex.isReserved(gt.name) + ? "<" + gt.name + "> at line " + gt.line + " not inside a loop" + : "no visible label '" + gt.name + "' for at line " + gt.line + ); + } + private void leaveBlock(FuncState fs) throws CompileException { FuncState.BlockCnt bl = fs.block; + if (bl.previous != null && bl.upval) { + int j = fs.jump(); + fs.patchClose(j, bl.activeVariableCount); + fs.patchToHere(j); + } + if (bl.isLoop) breakLabel(); + fs.block = bl.previous; - removeVars(bl.nactvar); - if (bl.upval) fs.codeABC(OP_CLOSE, bl.nactvar, 0, 0); - // a block either controls scope or breaks (never both) - assert !bl.isbreakable || !bl.upval; - assert bl.nactvar == fs.activeVariableCount; + removeVars(bl.activeVariableCount); + + assert bl.activeVariableCount == fs.activeVariableCount; fs.freeReg = fs.activeVariableCount; // free registers - fs.patchToHere(bl.breaklist.value); + while (activeLabels.size() > bl.firstLabel) activeLabels.remove(activeLabels.size() - 1); + if (bl.previous != null) { + moveGotosOut(bl); // Update pending gotos to outer block + } else if (bl.firstGoto < pendingGotos.size()) { + // We have an unsolved goto - error. + throw undefGoto(pendingGotos.get(bl.firstGoto)); + } } - private void pushClosure(FuncState child, Prototype childPrototype, ExpDesc v) throws CompileException { + /** + * Codes instruction to create new closure in parent function. + */ + private void codeClosure(Prototype childPrototype, ExpDesc v) throws CompileException { FuncState current = fs; int index = current.children.size(); current.children.add(childPrototype); - v.init(ExpKind.VRELOCABLE, current.codeABx(Lua.OP_CLOSURE, 0, index)); - for (FuncState.UpvalueDesc upvalue : child.upvalues) { - int op = upvalue.kind == ExpKind.VLOCAL ? Lua.OP_MOVE : Lua.OP_GETUPVAL; - current.codeABC(op, 0, upvalue.info, 0); - } + v.init(ExpKind.VRELOCABLE, current.codeABx(OP_CLOSURE, 0, index)); + fs.exp2NextReg(v); } - FuncState openFunc() { - FuncState fs = new FuncState(lexer, this.fs); + FuncState openFunc() throws CompileException { + if (fs != null) checkLimit(fs, fs.children.size(), MAXARG_Bx, "functions"); + FuncState fs = new FuncState(lexer, this.fs, activeVariableSize); this.fs = fs; + enterBlock(fs, false); return fs; } Prototype closeFunc() throws CompileException { FuncState fs = this.fs; - removeVars(0); - fs.ret(0, 0); /* final return */ + fs.ret(0, 0); // final return + leaveBlock(fs); assert fs.block == null; this.fs = fs.prev; @@ -358,10 +549,34 @@ Prototype closeFunc() throws CompileException { /* GRAMMAR RULES */ /*============================================================*/ - private void field(ExpDesc v) throws CompileException, LuaError, UnwindThrowable { + /** + * Check whether current token is in the follow set of a block. + *

+ * {@code until} closes syntactical blocks, but do not close scope, so it is handled separatly. + */ + private boolean blockFollow(boolean withUntil) { + return switch (lexer.token.token()) { + case TK_ELSE, TK_ELSEIF, TK_END, TK_EOS -> true; + case TK_UNTIL -> withUntil; + default -> false; + }; + } + + private void statementList() throws CompileException, LuaError, UnwindThrowable { + /* statlist -> { stat [`;'] } */ + while (!blockFollow(true)) { + if (lexer.token.token() == TK_RETURN) { + statement(); + return; + } + statement(); + } + } + + private void fieldSelect(ExpDesc v) throws CompileException, LuaError, UnwindThrowable { /* field -> ['.' | ':'] NAME */ ExpDesc key = new ExpDesc(); - fs.exp2AnyReg(v); + fs.exp2AnyRegUp(v); long indexPos = lexer.token.position(); lexer.nextToken(); // skip the dot or colon checkName(key); @@ -444,6 +659,22 @@ private void listField(ConsControl cc) throws CompileException, LuaError, Unwind cc.toStore++; } + private void field(ConsControl cc) throws CompileException, LuaError, UnwindThrowable { + // field -> listfield | recfield + switch (lexer.token.token()) { + case TK_NAME -> { // may be listfields or recfields + lexer.lookahead(); + if (lexer.lookahead.token() != '=') { // expression? + listField(cc); + } else { + recordField(cc); + } + } + case '[' -> recordField(cc); + default -> listField(cc); + } + } + private void constructor(ExpDesc t) throws CompileException, LuaError, UnwindThrowable { /* constructor -> ?? */ FuncState fs = this.fs; @@ -457,24 +688,8 @@ private void constructor(ExpDesc t) throws CompileException, LuaError, UnwindThr do { assert cc.v.kind == ExpKind.VVOID || cc.toStore > 0; if (lexer.token.token() == '}') break; - closeListField(fs, cc); - switch (lexer.token.token()) { - case TK_NAME -> { // may be listfields or recfields - lexer.lookahead(); - if (lexer.lookahead.token() != '=') { // expression? - listField(cc); - } else { - recordField(cc); - } - } - case '[' -> { // constructor_item -> recfield - recordField(cc); - } - default -> { // constructor_part -> listfield - listField(cc); - } - } + field(cc); } while (testNext(',') || testNext(';')); checkMatch('}', '{', line); lastListField(fs, cc); @@ -509,27 +724,23 @@ private void parlist() throws CompileException, LuaError, UnwindThrowable { /* parlist -> [ param { `,' param } ] */ FuncState fs = this.fs; int numParams = 0; - fs.varargFlags = 0; if (lexer.token.token() != ')') { // is `parlist' not empty? do { switch (lexer.token.token()) { case TK_NAME -> { // param . NAME - newLocal(strCheckName(), numParams++); + newLocal(strCheckName()); + numParams++; } case TK_DOTS -> { // param . `...' lexer.nextToken(); - if (LUA_COMPAT_VARARG) { // use `arg' as default name - newLocal("arg", numParams++); - fs.varargFlags = Lua.VARARG_HASARG | Lua.VARARG_NEEDSARG; - } - fs.varargFlags |= Lua.VARARG_ISVARARG; + fs.isVararg = true; } default -> throw syntaxError(" or " + LUA_QL("...") + " expected"); } - } while (fs.varargFlags == 0 && testNext(',')); + } while (!fs.isVararg && testNext(',')); } adjustLocalVars(numParams); - fs.numParams = fs.activeVariableCount - (fs.varargFlags & Lua.VARARG_HASARG); + fs.numParams = fs.activeVariableCount; fs.reserveRegs(fs.activeVariableCount); /* reserve register for parameters */ } @@ -539,16 +750,17 @@ private void body(ExpDesc e, boolean needSelf, int line) throws CompileException newFuncState.lineDefined = line; checkNext('('); if (needSelf) { - newLocal("self", 0); + newLocal("self"); adjustLocalVars(1); } parlist(); checkNext(')'); - chunk(); + statementList(); newFuncState.lastLineDefined = lexer.token.line(); checkMatch(TK_END, TK_FUNCTION, line); + Prototype proto = closeFunc(); - pushClosure(newFuncState, proto, e); + codeClosure(proto, e); } private int expList1(ExpDesc v) throws CompileException, LuaError, UnwindThrowable { @@ -570,12 +782,8 @@ private void funcArgs(ExpDesc f) throws CompileException, LuaError, UnwindThrowa int line = lexer.token.line(); switch (lexer.token.token()) { case '(' -> { /* funcargs -> `(' [ explist1 ] `)' */ - if (line != lexer.lastLine()) { - throw syntaxError("ambiguous syntax (function call x new statement)"); - } - lexer.nextToken(); - if (lexer.token.token() == ')') /* arg list is empty? */ { + if (lexer.token.token() == ')') { // arg list is empty? args.kind = ExpKind.VVOID; } else { expList1(args); @@ -613,8 +821,8 @@ private void funcArgs(ExpDesc f) throws CompileException, LuaError, UnwindThrowa ** ======================================================================= */ - private void prefixExpression(ExpDesc v) throws CompileException, LuaError, UnwindThrowable { - /* prefixexp -> NAME | '(' expr ')' */ + private void primaryExpression(ExpDesc v) throws CompileException, LuaError, UnwindThrowable { + // primaryexp -> NAME | '(' expr ')' switch (lexer.token.token()) { case '(' -> { int line = lexer.token.line(); @@ -622,31 +830,26 @@ private void prefixExpression(ExpDesc v) throws CompileException, LuaError, Unwi expression(v); checkMatch(')', '(', line); fs.dischargeVars(v); - return; } case TK_NAME -> { singleVar(v); - return; } default -> throw syntaxError("unexpected symbol"); } } - private void primaryExpression(ExpDesc v) throws CompileException, LuaError, UnwindThrowable { - /* - * primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | - * funcargs } - */ + private void suffixedExpression(ExpDesc v) throws CompileException, LuaError, UnwindThrowable { + // suffixedexp -> primaryexp { '.' NAME | '[' exp ']' | ':' NAME funcargs | funcargs } FuncState fs = this.fs; - prefixExpression(v); + primaryExpression(v); while (true) { switch (lexer.token.token()) { case '.' -> { // field - field(v); + fieldSelect(v); } case '[' -> { // `[' exp1 `]' ExpDesc key = new ExpDesc(); - fs.exp2AnyReg(v); + fs.exp2AnyRegUp(v); long indexPos = lexer.token.position(); yindex(key); fs.indexed(v, key, indexPos); @@ -672,29 +875,20 @@ private void primaryExpression(ExpDesc v) throws CompileException, LuaError, Unw private void simpleExpression(ExpDesc v) throws CompileException, LuaError, UnwindThrowable { /* * simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | - * FUNCTION body | primaryexp + * FUNCTION body | suffixedexp */ switch (lexer.token.token()) { case TK_NUMBER -> { v.init(ExpKind.VKNUM, 0); v.setNval(lexer.token.numberContents()); } - case TK_STRING -> { - codeString(v, lexer.token.stringContents()); - } - case TK_NIL -> { - v.init(ExpKind.VNIL, 0); - } - case TK_TRUE -> { - v.init(ExpKind.VTRUE, 0); - } - case TK_FALSE -> { - v.init(ExpKind.VFALSE, 0); - } + case TK_STRING -> codeString(v, lexer.token.stringContents()); + case TK_NIL -> v.init(ExpKind.VNIL, 0); + case TK_TRUE -> v.init(ExpKind.VTRUE, 0); + case TK_FALSE -> v.init(ExpKind.VFALSE, 0); case TK_DOTS -> { /* vararg */ FuncState fs = this.fs; - checkCondition(fs.varargFlags != 0, "cannot use " + LUA_QL("...") + " outside a vararg function"); - fs.varargFlags &= ~Lua.VARARG_NEEDSARG; // don't need 'arg' + checkCondition(fs.isVararg, "cannot use " + LUA_QL("...") + " outside a vararg function"); v.init(ExpKind.VVARARG, fs.codeABC(Lua.OP_VARARG, 0, 1, 0)); } case '{' -> { /* constructor */ @@ -707,7 +901,7 @@ private void simpleExpression(ExpDesc v) throws CompileException, LuaError, Unwi return; } default -> { - primaryExpression(v); + suffixedExpression(v); return; } } @@ -759,26 +953,15 @@ private void expression(ExpDesc v) throws CompileException, LuaError, UnwindThro ** ======================================================================= */ - private static boolean blockFollow(int token) { - return switch (token) { - case TK_ELSE, TK_ELSEIF, TK_END, TK_UNTIL, TK_EOS -> true; - default -> false; - }; - } - private void block() throws CompileException, LuaError, UnwindThrowable { /* block -> chunk */ - FuncState fs = this.fs; - FuncState.BlockCnt bl = new FuncState.BlockCnt(); - enterBlock(fs, bl, false); - chunk(); - LuaC._assert(bl.breaklist.value == NO_JUMP); + enterBlock(fs, false); + statementList(); leaveBlock(fs); } /* - ** structure to chain all variables in the left-hand side of an - ** assignment + ** Structure to chain all variables in the left-hand side of an assignment */ static class LhsAssign { final LhsAssign prev; @@ -790,10 +973,10 @@ static class LhsAssign { } /* - ** check whether, in an assignment to a local variable, the local variable - ** is needed in a previous assignment (to a table). If so, save original - ** local value in a safe place and use this safe copy in the previous - ** assignment. + * Check whether, in an assignment to a local variable, the local variable + * is needed in a previous assignment (to a table). If so, save original + * local value in a safe place and use this safe copy in the previous + * assignment. */ private void checkConflict(LhsAssign lh, ExpDesc v) throws CompileException { FuncState fs = this.fs; @@ -801,18 +984,24 @@ private void checkConflict(LhsAssign lh, ExpDesc v) throws CompileException { boolean conflict = false; for (; lh != null; lh = lh.prev) { if (lh.v.kind == ExpKind.VINDEXED) { - if (lh.v.info == v.info) { // conflict? + // table is the upvalue/local being assigned now? + if (lh.v.tableType == v.kind && lh.v.info == v.info) { conflict = true; + lh.v.tableType = ExpKind.VLOCAL; lh.v.info = extra; // previous assignment will use safe copy } - if (lh.v.aux == v.info) { // conflict? + + // index is the local being assigned? (index cannot be upvalue) + if (v.kind == ExpKind.VLOCAL && lh.v.aux == v.info) { conflict = true; lh.v.aux = extra; // previous assignment will use safe copy } } } + if (conflict) { - fs.codeABC(Lua.OP_MOVE, fs.freeReg, v.info, 0); /* make copy */ + var opcode = v.kind == ExpKind.VLOCAL ? OP_MOVE : OP_GETUPVAL; + fs.codeABC(opcode, extra, v.info, 0); /* make copy */ fs.reserveRegs(1); } } @@ -822,15 +1011,14 @@ private void assignment(LhsAssign lh, int nvars) throws CompileException, LuaErr checkCondition(lh.v.kind.isVar(), "syntax error"); if (testNext(',')) { // assignment -> `,' primaryexp assignment LhsAssign nv = new LhsAssign(lh); - primaryExpression(nv.v); - if (nv.v.kind == ExpKind.VLOCAL) checkConflict(lh, nv.v); + suffixedExpression(nv.v); + if (nv.v.kind != ExpKind.VINDEXED) checkConflict(lh, nv.v); assignment(nv, nvars + 1); } else { /* assignment . `=' explist1 */ checkNext('='); int nexps = expList1(e); if (nexps != nvars) { adjustAssign(nvars, nexps, e); - if (nexps > nvars) fs.freeReg -= nexps - nvars; // remove extra values } else { fs.setOneRet(e); // close last expression fs.storeVar(lh.v, e); @@ -850,18 +1038,53 @@ private int cond() throws CompileException, LuaError, UnwindThrowable { return v.f.value; } - private void breakStmt() throws CompileException { - FuncState fs = this.fs; - FuncState.BlockCnt bl = fs.block; - boolean upval = false; - while (bl != null && !bl.isbreakable) { - upval |= bl.upval; - bl = bl.previous; + private void gotoLabel(int pc, int line, LuaString label) throws CompileException { + int g = newLabelEntry(pendingGotos, label, line, pc); + findLabel(g); // Link if label already defined. + } + + private void breakStmt(int pc) throws CompileException, LuaError, UnwindThrowable { + int line = lexer.lastLine(); + lexer.nextToken(); + gotoLabel(pc, line, lexer.newString("break")); + } + + private void gotoStat(int pc) throws CompileException, LuaError, UnwindThrowable { + int line = lexer.lastLine(); + lexer.nextToken(); + gotoLabel(pc, line, strCheckName()); + } + + /** + * Check for repeated labels on the same block. + * + * @param name The name of the label. + */ + private void checkRepeated(LuaString name) throws CompileException { + for (int i = fs.block.firstLabel; i < activeLabels.size(); i++) { + var label = activeLabels.get(i); + if (label.name == name) throw semError("label '" + name + "' already defined on line " + label.line); } - if (bl == null) throw syntaxError("no loop to break"); + } - if (upval) fs.codeABC(OP_CLOSE, bl.nactvar, 0, 0); - fs.concat(bl.breaklist, fs.jump()); + /** + * Skip no-op statements + */ + private void skipNoOpStatements() throws CompileException, LuaError, UnwindThrowable { + while (lexer.token.token() == ';' || lexer.token.token() == TK_DBCOLON) statement(); + } + + private void labelStat(LuaString name, int line) throws CompileException, LuaError, UnwindThrowable { + // label -> '::' NAME '::' + checkRepeated(name); + checkNext(TK_DBCOLON); + var label = newLabelEntry(activeLabels, name, line, fs.getLabel()); + skipNoOpStatements(); + if (blockFollow(false)) { + // If label is last statement in the block, assume locals are already out of scope. + activeLabels.get(label).activeVariables = fs.block.activeVariableCount; + } + findGotos(activeLabels.get(label)); } private void whileStmt() throws CompileException, LuaError, UnwindThrowable { @@ -893,19 +1116,13 @@ private void repeatStmt() throws CompileException, LuaError, UnwindThrowable { FuncState.BlockCnt bl2 = new FuncState.BlockCnt(); enterBlock(fs, bl1, true); /* loop block */ enterBlock(fs, bl2, false); /* scope block */ - chunk(); + statementList(); checkMatch(TK_UNTIL, TK_REPEAT, line); int condexit = cond(); // read condition (inside scope block) - if (!bl2.upval) { // no upvalues? - leaveBlock(fs); // finish scope - fs.patchList(condexit, repeatInit); // close the loop - } else { /* complete semantics when there are upvalues */ - breakStmt(); /* if condition then break */ - fs.patchToHere(condexit); /* else... */ - leaveBlock(fs); /* finish scope... */ - fs.patchList(fs.jump(), repeatInit); /* and repeat */ - } - leaveBlock(fs); // finish loop */ + if (bl2.upval) fs.patchClose(condexit, bl2.activeVariableCount); // Close upvalues + leaveBlock(fs); // finish scope */ + fs.patchList(condexit, repeatInit); /* close the loop */ + leaveBlock(fs); /* finish loop */ } private void exp1() throws CompileException, LuaError, UnwindThrowable { @@ -927,20 +1144,24 @@ private void forBody(int base, long position, int nvars, boolean isNum) throws C block(); leaveBlock(fs); /* end of scope for declared variables */ fs.patchToHere(prep); - int endFor = isNum - ? fs.codeAsBxAt(Lua.OP_FORLOOP, base, NO_JUMP, position) - : fs.codeABCAt(Lua.OP_TFORLOOP, base, 0, nvars, position); - fs.patchList(isNum ? endFor : fs.jump(), prep + 1); + int endFor; + if (isNum) { + endFor = fs.codeAsBxAt(OP_FORLOOP, base, NO_JUMP, position); + } else { + fs.codeABCAt(OP_TFORCALL, base, 0, nvars, position); + endFor = fs.codeAsBxAt(OP_TFORLOOP, base + 2, NO_JUMP, position); + } + fs.patchList(endFor, prep + 1); } private void forNum(LuaString varName, long position) throws CompileException, LuaError, UnwindThrowable { /* fornum -> NAME = exp1,exp1[,exp1] forbody */ FuncState fs = this.fs; int base = fs.freeReg; - newLocal("(for index)", 0); - newLocal("(for limit)", 1); - newLocal("(for step)", 2); - newLocal(varName, 3); + newLocal("(for index)"); + newLocal("(for limit)"); + newLocal("(for step)"); + newLocal(varName); checkNext('='); exp1(); /* initial value */ checkNext(','); @@ -948,7 +1169,7 @@ private void forNum(LuaString varName, long position) throws CompileException, L if (testNext(',')) { exp1(); /* optional step */ } else { /* default step = 1 */ - fs.codeABx(Lua.OP_LOADK, fs.freeReg, fs.numberK(LuaInteger.valueOf(1))); + fs.codeK(fs.freeReg, fs.numberK(LuaInteger.valueOf(1))); fs.reserveRegs(1); } forBody(base, position, 1, true); @@ -958,15 +1179,18 @@ private void forList(LuaString indexName) throws CompileException, LuaError, Unw /* forlist -> NAME {,NAME} IN explist1 forbody */ FuncState fs = this.fs; ExpDesc e = new ExpDesc(); - int nvars = 0; + int nvars = 4; int base = fs.freeReg; /* create control variables */ - newLocal("(for generator)", nvars++); - newLocal("(for state)", nvars++); - newLocal("(for control)", nvars++); + newLocal("(for generator)"); + newLocal("(for state)"); + newLocal("(for control)"); /* create declared variables */ - newLocal(indexName, nvars++); - while (testNext(',')) newLocal(strCheckName(), nvars++); + newLocal(indexName); + while (testNext(',')) { + newLocal(strCheckName()); + nvars++; + } checkNext(TK_IN); long position = lexer.token.position(); adjustAssign(3, expList1(e), e); @@ -992,13 +1216,51 @@ private void forStmt() throws CompileException, LuaError, UnwindThrowable { leaveBlock(fs); // loop scope (`break' jumps to this point) } - private int testThenBlock() throws CompileException, LuaError, UnwindThrowable { + private void testThenBlock(IntPtr escapeList) throws CompileException, LuaError, UnwindThrowable { // test_then_block -> [IF | ELSEIF] cond THEN block lexer.nextToken(); // skip IF or ELSEIF - int condExit = cond(); + + // Read condition; + ExpDesc v = new ExpDesc(); + expression(v); + checkNext(TK_THEN); - block(); // `then' part - return condExit; + + int jumpFalse; + if (lexer.token.token() == TK_BREAK || (lexer.token.token() == TK_NAME && tryGoto())) { + fs.goIfFalse(v); // Jump to label if condition is true + enterBlock(fs, false); + if (lexer.token.token() == TK_BREAK) { + breakStmt(v.t.value); + } else { + gotoStat(v.t.value); + } + + while (testNext(';')) ; // Skip semicolons + + // If the block is just some goto, we can bail immediately + if (blockFollow(false)) { + leaveBlock(fs); + return; + } else { + // Otherwise we need to skip over the rest of the body. + jumpFalse = fs.jump(); + } + } else { + fs.goIfTrue(v); // Skip over block if condition is false + enterBlock(fs, false); + jumpFalse = v.f.value; + } + + statementList(); // `then' part + leaveBlock(fs); + + // If followed by an else/elseif, emit something to jump over that block + if (lexer.token.token() == TK_ELSE || lexer.token.token() == TK_ELSEIF) { + fs.concat(escapeList, fs.jump()); + } + + fs.patchToHere(jumpFalse); } private void ifStat() throws CompileException, LuaError, UnwindThrowable { @@ -1007,37 +1269,22 @@ private void ifStat() throws CompileException, LuaError, UnwindThrowable { FuncState fs = this.fs; IntPtr escapeList = new IntPtr(NO_JUMP); - int flist = testThenBlock(); /* IF cond THEN block */ - while (lexer.token.token() == TK_ELSEIF) { - fs.concat(escapeList, fs.jump()); - fs.patchToHere(flist); - flist = testThenBlock(); /* ELSEIF cond THEN block */ - } - if (lexer.token.token() == TK_ELSE) { - fs.concat(escapeList, fs.jump()); - fs.patchToHere(flist); - /* skip ELSE (after patch, for correct line info) */ - lexer.nextToken(); - block(); /* `else' part */ - } else { - fs.concat(escapeList, flist); - } - fs.patchToHere(escapeList.value); + testThenBlock(escapeList); /* IF cond THEN block */ + while (lexer.token.token() == TK_ELSEIF) testThenBlock(escapeList); // ELSEIF cond THEN block */ + if (testNext(TK_ELSE)) block(); // `else' part + checkMatch(TK_END, TK_IF, line); + fs.patchToHere(escapeList.value); // Patch list to "if" end. } private void localFunc() throws CompileException, LuaError, UnwindThrowable { - ExpDesc v = new ExpDesc(); FuncState fs = this.fs; - newLocal(strCheckName(), 0); - v.init(ExpKind.VLOCAL, fs.freeReg); - fs.reserveRegs(1); + newLocal(strCheckName()); adjustLocalVars(1); + ExpDesc b = new ExpDesc(); body(b, false, lexer.token.line()); - fs.storeVar(v, b); - // debug information will only see the variable after this point! - fs.getLocal(fs.activeVariableCount - 1).startpc = fs.pc; + getLocal(fs, b.info).startpc = fs.pc; } private void localStmt() throws CompileException, LuaError, UnwindThrowable { @@ -1045,7 +1292,8 @@ private void localStmt() throws CompileException, LuaError, UnwindThrowable { int nvars = 0; ExpDesc e = new ExpDesc(); do { - newLocal(strCheckName(), nvars++); + newLocal(strCheckName()); + nvars++; } while (testNext(',')); int nexps; @@ -1062,12 +1310,12 @@ private void localStmt() throws CompileException, LuaError, UnwindThrowable { private boolean funcName(ExpDesc v) throws CompileException, LuaError, UnwindThrowable { // funcname -> NAME {field} [`:' NAME] singleVar(v); - while (lexer.token.token() == '.') field(v); + while (lexer.token.token() == '.') fieldSelect(v); boolean needSelf = false; if (lexer.token.token() == ':') { needSelf = true; - field(v); + fieldSelect(v); } return needSelf; } @@ -1088,11 +1336,13 @@ private void exprStmt() throws CompileException, LuaError, UnwindThrowable { // stat -> func | assignment FuncState fs = this.fs; LhsAssign v = new LhsAssign(null); - primaryExpression(v.v); - if (v.v.kind == ExpKind.VCALL) { // stat -> func - fs.code[v.v.info] = LuaC.SETARG_C(fs.code[v.v.info], 1); // call statement uses no results - } else { // stat -> assignment + suffixedExpression(v.v); + if (lexer.token.token() == '=' || lexer.token.token() == ',') { // stat -> assignment assignment(v, 1); + } else if (v.v.kind == ExpKind.VCALL) { // stat -> func + fs.code[v.v.info] = LuaC.SETARG_C(fs.code[v.v.info], 1); // call statement uses no results + } else { + throw syntaxError("syntax error"); } } @@ -1101,7 +1351,7 @@ private void returnStmt() throws CompileException, LuaError, UnwindThrowable { FuncState fs = this.fs; int first, nret; // registers with returned values lexer.nextToken(); // skip RETURN - if (blockFollow(lexer.token.token()) || lexer.token.token() == ';') { + if (blockFollow(true) || lexer.token.token() == ';') { first = nret = 0; // return no values } else { ExpDesc e = new ExpDesc(); @@ -1109,8 +1359,8 @@ private void returnStmt() throws CompileException, LuaError, UnwindThrowable { if (e.kind.hasMultiRet()) { fs.setMultiRet(e); if (e.kind == ExpKind.VCALL && nret == 1) { /* tail call? */ - int op = fs.code[e.info] = LuaC.SET_OPCODE(fs.code[e.info], Lua.OP_TAILCALL); - LuaC._assert(Lua.GETARG_A(op) == fs.activeVariableCount); + int op = fs.code[e.info] = LuaC.SET_OPCODE(fs.code[e.info], OP_TAILCALL); + assert GETARG_A(op) == fs.activeVariableCount; } first = fs.activeVariableCount; nret = Lua.LUA_MULTRET; /* return all values */ @@ -1125,37 +1375,33 @@ private void returnStmt() throws CompileException, LuaError, UnwindThrowable { } } fs.ret(first, nret); + testNext(';'); + } + + private boolean tryGoto() throws CompileException, LuaError, UnwindThrowable { + assert lexer.token.token() == TK_NAME; + if (lexer.token.stringContents() != gotoName) return false; + + lexer.lookahead(); + return lexer.lookahead.token() == TK_NAME; } - private boolean statement() throws CompileException, LuaError, UnwindThrowable { + private void statement() throws CompileException, LuaError, UnwindThrowable { + enterLevel(); + switch (lexer.token.token()) { - case TK_IF -> { // stat -> ifstat - ifStat(); - return false; - } - case TK_WHILE -> { /* stat -> whiles-tat */ - whileStmt(); - return false; - } + case ';' -> lexer.nextToken(); + case TK_IF -> ifStat(); // stat -> ifstat + case TK_WHILE -> whileStmt(); // stat -> whiles-tat case TK_DO -> { /* stat -> DO block END */ int line = lexer.token.line(); // may be needed for error messages lexer.nextToken(); // skip DO block(); checkMatch(TK_END, TK_DO, line); - return false; - } - case TK_FOR -> { /* stat -> forstat */ - forStmt(); - return false; - } - case TK_REPEAT -> { /* stat -> repeatstat */ - repeatStmt(); - return false; - } - case TK_FUNCTION -> { - funcStmt(); /* stat -> funcstat */ - return false; } + case TK_FOR -> forStmt(); // stat -> forstat + case TK_REPEAT -> repeatStmt(); // stat -> repeatstat + case TK_FUNCTION -> funcStmt(); // stat -> funcstat case TK_LOCAL -> { /* stat -> localstat */ lexer.nextToken(); // skip LOCAL if (testNext(TK_FUNCTION)) { // local function? @@ -1163,36 +1409,49 @@ private boolean statement() throws CompileException, LuaError, UnwindThrowable { } else { localStmt(); } - return false; - } - case TK_RETURN -> { /* stat -> retstat */ - returnStmt(); - return true; // must be last statement } - case TK_BREAK -> { /* stat -> breakstat */ - lexer.nextToken(); // skip BREAK - breakStmt(); - return true; // must be last statement + case TK_DBCOLON -> { + lexer.nextToken(); // skip :: + int line = lexer.lastLine(); + labelStat(strCheckName(), line); } - default -> { - exprStmt(); - return false; + case TK_RETURN -> returnStmt(); /* stat -> retstat */ + + case TK_BREAK -> breakStmt(fs.jump()); /* stat -> breakstat */ + + case TK_NAME -> { + if (tryGoto()) { + gotoStat(fs.jump()); + } else { + exprStmt(); + } } - } - } - void chunk() throws CompileException, LuaError, UnwindThrowable { - /* chunk -> { stat [`;'] } */ - boolean islast = false; - enterLevel(); - while (!islast && !blockFollow(lexer.token.token())) { - islast = statement(); - testNext(';'); - assert fs.maxStackSize >= fs.freeReg && fs.freeReg >= fs.activeVariableCount; - fs.freeReg = fs.activeVariableCount; /* free registers */ + default -> exprStmt(); } + + assert fs.maxStackSize >= fs.freeReg && fs.freeReg >= fs.activeVariableCount; + fs.freeReg = fs.activeVariableCount; /* free registers */ + leaveLevel(); } /* }====================================================================== */ + + public Prototype mainFunction() throws CompileException, LuaError, UnwindThrowable { + FuncState funcstate = openFunc(); + funcstate.isVararg = true; /* main func. is always vararg */ + + var v = new ExpDesc(); + v.init(ExpKind.VLOCAL, 0); + newUpvalue(funcstate, envName, v); + lexer.nextToken(); // read first token + statementList(); + check(Lex.TK_EOS); + + Prototype prototype = closeFunc(); + assert funcstate.upvalues.size() == 1; + assert fs == null; + return prototype; + } } diff --git a/src/main/java/org/squiddev/cobalt/debug/DebugHelpers.java b/src/main/java/org/squiddev/cobalt/debug/DebugHelpers.java index 6301c538..00ab5404 100644 --- a/src/main/java/org/squiddev/cobalt/debug/DebugHelpers.java +++ b/src/main/java/org/squiddev/cobalt/debug/DebugHelpers.java @@ -194,21 +194,16 @@ private static ObjectName fromMetamethod(String name) { int i = p.code[pc]; switch (GET_OPCODE(i)) { - case OP_GETGLOBAL -> { - int g = GETARG_Bx(i); /* global index */ - // lua_assert(p.k[g].isString()); - LuaValue value = p.constants[g]; - LuaString string = OperationHelper.toStringDirect(value); - return new ObjectName(string, GLOBAL); - } case OP_MOVE -> { int a = GETARG_A(i); int b = GETARG_B(i); /* move from `b' to `a' */ if (b < a) return getObjectName(di, b); /* get name for `b' */ } - case OP_GETTABLE -> { - int k = GETARG_C(i); /* key index */ - return new ObjectName(constantName(p, k), FIELD); + case OP_GETTABUP, OP_GETTABLE -> { + int t = GETARG_B(i); + LuaString table = GET_OPCODE(i) == OP_GETTABUP ? p.getUpvalueName(t) : p.getLocalName(t + 1, pc); + int c = GETARG_C(i); /* key index */ + return new ObjectName(constantName(p, c), Objects.equals(table, Constants.ENV) ? GLOBAL : FIELD); } case OP_GETUPVAL -> { int u = GETARG_B(i); /* upvalue index */ @@ -240,7 +235,7 @@ private static int findSetReg(Prototype pt, int lastpc, int reg) { int b = GETARG_B(i); if (a <= reg && reg <= a + b) setreg = filterPc(pc, jumpTarget); } - case OP_TFORLOOP -> { + case OP_TFORCALL -> { if (a >= a + 2) setreg = filterPc(pc, jumpTarget); } case OP_CALL, OP_TAILCALL -> { diff --git a/src/main/java/org/squiddev/cobalt/function/HasEnvironment.java b/src/main/java/org/squiddev/cobalt/function/HasEnvironment.java deleted file mode 100644 index 5f7e4b40..00000000 --- a/src/main/java/org/squiddev/cobalt/function/HasEnvironment.java +++ /dev/null @@ -1,2 +0,0 @@ -package org.squiddev.cobalt.function;public interface HasEnvironment { -} diff --git a/src/main/java/org/squiddev/cobalt/function/LibFunction.java b/src/main/java/org/squiddev/cobalt/function/LibFunction.java index 7d596e35..0b963d89 100644 --- a/src/main/java/org/squiddev/cobalt/function/LibFunction.java +++ b/src/main/java/org/squiddev/cobalt/function/LibFunction.java @@ -97,9 +97,6 @@ * data it needs to and place it into the environment if needed. * In this case, it creates two function, 'sinh', and 'cosh', and puts * them into a global table called 'hyperbolic.' - * It placed the library table into the globals via the {@link #getfenv()} - * local variable which corresponds to the globals that apply when the - * library is loaded. *

* To test it, a script such as this can be used: *

 {@code
diff --git a/src/main/java/org/squiddev/cobalt/function/LuaClosure.java b/src/main/java/org/squiddev/cobalt/function/LuaClosure.java
index a26a0274..4d0098bf 100644
--- a/src/main/java/org/squiddev/cobalt/function/LuaClosure.java
+++ b/src/main/java/org/squiddev/cobalt/function/LuaClosure.java
@@ -24,7 +24,6 @@
  */
 package org.squiddev.cobalt.function;
 
-import org.squiddev.cobalt.LuaTable;
 import org.squiddev.cobalt.Prototype;
 import org.squiddev.cobalt.debug.Upvalue;
 
@@ -32,10 +31,6 @@
  * A lua function that provides a coroutine.
  */
 public abstract class LuaClosure extends LuaFunction {
-	public LuaClosure(LuaTable env) {
-		super(env);
-	}
-
 	/**
 	 * Get the prototype for this closure
 	 *
diff --git a/src/main/java/org/squiddev/cobalt/function/LuaFunction.java b/src/main/java/org/squiddev/cobalt/function/LuaFunction.java
index b7dbc3af..6de3855c 100644
--- a/src/main/java/org/squiddev/cobalt/function/LuaFunction.java
+++ b/src/main/java/org/squiddev/cobalt/function/LuaFunction.java
@@ -24,7 +24,6 @@
  */
 package org.squiddev.cobalt.function;
 
-import org.checkerframework.checker.nullness.qual.Nullable;
 import org.squiddev.cobalt.*;
 
 /**
@@ -40,17 +39,10 @@
  * @see LuaInterpretedFunction
  */
 public abstract class LuaFunction extends LuaValue {
-	private @Nullable LuaTable env;
-
 	public LuaFunction() {
 		super(Constants.TFUNCTION);
 	}
 
-	public LuaFunction(LuaTable env) {
-		super(Constants.TFUNCTION);
-		this.env = env;
-	}
-
 	@Override
 	public final LuaFunction checkFunction() {
 		return this;
@@ -61,17 +53,6 @@ public final LuaTable getMetatable(LuaState state) {
 		return state.functionMetatable;
 	}
 
-	@Override
-	public final LuaTable getfenv() {
-		return env;
-	}
-
-	@Override
-	public final boolean setfenv(LuaTable env) {
-		this.env = env;
-		return true;
-	}
-
 	public abstract String debugName();
 
 	/**
diff --git a/src/main/java/org/squiddev/cobalt/function/LuaInterpretedFunction.java b/src/main/java/org/squiddev/cobalt/function/LuaInterpretedFunction.java
index 0c54b360..b9e12a3b 100644
--- a/src/main/java/org/squiddev/cobalt/function/LuaInterpretedFunction.java
+++ b/src/main/java/org/squiddev/cobalt/function/LuaInterpretedFunction.java
@@ -78,7 +78,6 @@
  * Since a {@link LuaInterpretedFunction} is a {@link LuaFunction} which is a {@link LuaValue},
  * all the value operations can be used directly such as:
  *