Skip to content

Commit

Permalink
Poll interrupted state when reading from InputStreams
Browse files Browse the repository at this point in the history
This prevents hard aborts when reading incredibly long (1GB+ strings).
Ideally we wouldn't even be able to create strings that long, but that's
a whole 'nother issue!
  • Loading branch information
SquidDev committed Apr 4, 2024
1 parent e9f6a37 commit 97d2f36
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ private val subclassRelations = mapOf(
UnorderedPair("org/squiddev/cobalt/LuaValue", "org/squiddev/cobalt/LuaString") to "org/squiddev/cobalt/LuaValue",
UnorderedPair("org/squiddev/cobalt/Varargs", "org/squiddev/cobalt/LuaValue") to "org/squiddev/cobalt/Varargs",
UnorderedPair("org/squiddev/cobalt/LuaError", "org/squiddev/cobalt/compiler/CompileException") to "java/lang/Exception",
UnorderedPair("org/squiddev/cobalt/LuaError", "java/lang/Exception") to "java/lang/Exception",
UnorderedPair("org/squiddev/cobalt/compiler/CompileException", "java/lang/Exception") to "java/lang/Exception"
)

/** A [ClassWriter] extension which avoids loading classes when computing frames. */
Expand Down
31 changes: 26 additions & 5 deletions src/main/java/org/squiddev/cobalt/compiler/InputReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,31 @@
/**
* A basic byte-by-byte input stream, which can yield when reading.
*/
public interface InputReader {
int read() throws CompileException, LuaError, UnwindThrowable;

default int resume(Varargs varargs) throws CompileException, LuaError, UnwindThrowable {
throw new IllegalStateException("Cannot resume a non-yielding InputReader.");
public abstract class InputReader {
protected InputReader() {
}

/**
* Read a single byte from this input.
*
* @return The read byte.
* @throws LuaError If the underlying reader threw a Lua error.
* @throws CompileException If reading failed for some other reason. Unlike a {@link LuaError}, this will not be
* passed to the {@code xpcall} error handler.
* @throws UnwindThrowable If the reader yielded. {@link #resume(Varargs)} will be called when the coroutine is
* resumed.
*/
public abstract int read() throws CompileException, LuaError, UnwindThrowable;

/**
* Resume this reader after yielding.
*
* @param varargs The value returned from the function above this in the stack
* @return The read byte.
* @throws LuaError If the underlying reader threw a Lua error.
* @throws CompileException If reading failed for some other reason. Unlike a {@link LuaError}, this will not be
* passed to the {@code xpcall} error handler.
* @throws UnwindThrowable If the reader yielded.
*/
public abstract int resume(Varargs varargs) throws CompileException, LuaError, UnwindThrowable;
}
31 changes: 2 additions & 29 deletions src/main/java/org/squiddev/cobalt/compiler/LoadState.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
* Class to manage loading of {@link Prototype} instances.
* <p>
* The {@link LoadState} class exposes one main function,
* namely {@link #load(LuaState, InputStream, LuaString, LuaTable)},
* namely {@link #load(LuaState, InputStream, LuaString, LuaValue)},
* to be used to load code from a particular input stream.
* <p>
* A simple pattern for loading and executing code is
Expand All @@ -60,11 +60,6 @@
* @see LuaC
*/
public final class LoadState {
/**
* Name for compiled chunks
*/
private static final LuaString SOURCE_BINARY_STRING = valueOf("=?");

/**
* Construct our standard Lua function.
*/
Expand Down Expand Up @@ -108,29 +103,7 @@ public static LuaClosure load(LuaState state, InputStream stream, String name, L
* @throws CompileException If the stream cannot be loaded.
*/
public static LuaClosure load(LuaState state, InputStream stream, LuaString name, LuaValue env) throws CompileException, LuaError {
return load(state, stream, name, null, env);
}

public static LuaClosure load(LuaState state, InputStream stream, LuaString name, LuaString mode, LuaValue env) throws CompileException, LuaError {
return state.compiler.load(LuaC.compile(state, stream, name, mode), env);
}

/**
* Construct a source name from a supplied chunk name
*
* @param name String name that appears in the chunk
* @return source file name
*/
static LuaString getSourceName(LuaString name) {
if (name.length() > 0) {
return switch (name.charAt(0)) {
case '@', '=' -> name.substring(1);
case 27 -> SOURCE_BINARY_STRING;
default -> name;
};
}

return name;
return state.compiler.load(LuaC.compile(state, stream, name), env);
}

private static final int NAME_LENGTH = 30;
Expand Down
30 changes: 25 additions & 5 deletions src/main/java/org/squiddev/cobalt/compiler/LuaC.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import cc.tweaked.cobalt.internal.unwind.AutoUnwind;
import cc.tweaked.cobalt.internal.unwind.SuspendedAction;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.squiddev.cobalt.*;
import org.squiddev.cobalt.compiler.LoadState.FunctionFactory;
import org.squiddev.cobalt.function.LuaInterpretedFunction;
Expand Down Expand Up @@ -145,13 +146,13 @@ private LuaC() {
* @throws CompileException If there is a syntax error.
*/
public static Prototype compile(LuaState state, InputStream stream, String name) throws CompileException, LuaError {
return compile(state, stream, valueOf(name), null);
return compile(state, stream, valueOf(name));
}

public static Prototype compile(LuaState state, InputStream stream, LuaString name, LuaString mode) throws CompileException, LuaError {
public static Prototype compile(LuaState state, InputStream stream, LuaString name) throws CompileException, LuaError {
Object result = SuspendedAction.noYield(() -> {
try {
return compile(state, new InputStreamReader(stream), name, mode);
return compile(state, new InputStreamReader(stream), name, null);
} catch (CompileException e) {
return e;
}
Expand Down Expand Up @@ -186,15 +187,34 @@ private static Prototype loadTextChunk(int firstByte, InputReader stream, LuaStr
return parser.mainFunction();
}

public record InputStreamReader(InputStream stream) implements InputReader {
public static final class InputStreamReader extends InputReader {
private final @Nullable LuaState state;
private final InputStream stream;

public InputStreamReader(InputStream stream) {
this(null, stream);
}

public InputStreamReader(@Nullable LuaState state, InputStream stream) {
this.state = state;
this.stream = stream;
}

@Override
public int read() throws CompileException {
public int read() throws CompileException, UnwindThrowable, LuaError {
if (state != null && state.isInterrupted()) state.handleInterrupt();

try {
return stream.read();
} catch (IOException e) {
String message = e.getMessage() == null ? e.toString() : e.getMessage();
throw new CompileException("io error: " + message);
}
}

@Override
public int resume(Varargs varargs) throws CompileException, LuaError, UnwindThrowable {
return read();
}
}
}
49 changes: 28 additions & 21 deletions src/main/java/org/squiddev/cobalt/lib/BaseLib.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
import org.squiddev.cobalt.*;
import org.squiddev.cobalt.compiler.CompileException;
import org.squiddev.cobalt.compiler.InputReader;
import org.squiddev.cobalt.compiler.LoadState;
import org.squiddev.cobalt.compiler.LuaC;
import org.squiddev.cobalt.debug.DebugFrame;
import org.squiddev.cobalt.function.*;
import org.squiddev.cobalt.unwind.SuspendedTask;

import java.io.InputStream;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -69,7 +69,7 @@ public static void add(LuaTable env) {
RegisteredFunction.ofV("assert", BaseLib::assert_),
RegisteredFunction.of("getfenv", BaseLib::getfenv),
RegisteredFunction.ofV("getmetatable", BaseLib::getmetatable),
RegisteredFunction.ofV("loadstring", BaseLib::loadstring),
RegisteredFunction.ofS("loadstring", BaseLib::loadstring),
RegisteredFunction.ofV("select", BaseLib::select),
RegisteredFunction.ofV("type", BaseLib::type),
RegisteredFunction.ofV("rawequal", BaseLib::rawequal),
Expand Down Expand Up @@ -138,10 +138,11 @@ private static LuaValue getmetatable(LuaState state, Varargs args) throws LuaErr
return mt != null ? mt.rawget(Constants.METATABLE).optValue(mt) : Constants.NIL;
}

private static Varargs loadstring(LuaState state, Varargs args) throws LuaError {
private static Varargs loadstring(LuaState state, DebugFrame di, Varargs args) throws LuaError, UnwindThrowable {
// loadstring( string [,chunkname] ) -> chunk | nil, msg
LuaString script = args.arg(1).checkLuaString();
return BaseLib.loadStream(state, script.toInputStream(), args.arg(2).optLuaString(script));
InputStream is = script.toInputStream();
return loadStream(state, di, is, args.arg(2).optLuaString(script), null, state.globals());
}

private static Varargs select(LuaState state, Varargs args) throws LuaError {
Expand Down Expand Up @@ -311,7 +312,7 @@ public Varargs resumeError(LuaState state, ProtectedCall call, LuaError error) t
}

// load( func|str [,chunkname[, mode[, env]]] ) -> chunk | nil, msg
static class Load extends ResumableVarArgFunction<ProtectedCall> {
static class Load extends ResumableVarArgFunction<Object> {
@Override
protected Varargs invoke(LuaState state, DebugFrame di, Varargs args) throws LuaError, UnwindThrowable {
LuaValue scriptGen = args.arg(1);
Expand All @@ -322,7 +323,7 @@ protected Varargs invoke(LuaState state, DebugFrame di, Varargs args) throws Lua
// If we're a string, load as normal
if (scriptGen.isString()) {
LuaString contents = scriptGen.checkLuaString();
return BaseLib.loadStream(state, contents.toInputStream(), chunkName == null ? contents : chunkName, mode, funcEnv);
return BaseLib.loadStream(state, di, contents.toInputStream(), chunkName == null ? contents : chunkName, mode, funcEnv);
}

LuaFunction function = scriptGen.checkFunction();
Expand All @@ -339,29 +340,35 @@ protected Varargs invoke(LuaState state, DebugFrame di, Varargs args) throws Lua
}

@Override
public Varargs resume(LuaState state, ProtectedCall call, Varargs value) throws UnwindThrowable {
return call.resume(state, value).asResultOrFailure();
public Varargs resume(LuaState state, Object funcState, Varargs value) throws UnwindThrowable, LuaError {
if (funcState instanceof ProtectedCall call) {
return call.resume(state, value).asResultOrFailure();
} else {
return ((SuspendedTask<Varargs>) funcState).resume(value);
}
}

@Override
public Varargs resumeError(LuaState state, ProtectedCall call, LuaError error) throws UnwindThrowable {
return call.resumeError(state, error).asResultOrFailure();
}
}

public static Varargs loadStream(LuaState state, InputStream is, LuaString chunkName, LuaString mode, LuaValue env) {
try {
return LoadState.load(state, is, chunkName, mode, env);
} catch (LuaError | CompileException e) {
return varargsOf(Constants.NIL, valueOf(e.getMessage()));
public Varargs resumeError(LuaState state, Object funcState, LuaError error) throws UnwindThrowable, LuaError {
if (funcState instanceof ProtectedCall call) {
return call.resumeError(state, error).asResultOrFailure();
} else {
return super.resumeError(state, funcState, error);
}
}
}

public static Varargs loadStream(LuaState state, InputStream is, LuaString chunkName) {
return loadStream(state, is, chunkName, null, state.globals());
private static Varargs loadStream(LuaState state, DebugFrame frame, InputStream is, LuaString chunkName, LuaString mode, LuaValue env) throws UnwindThrowable, LuaError {
return SuspendedAction.run(frame, () -> {
try {
return state.compiler.load(LuaC.compile(state, new LuaC.InputStreamReader(state, is), chunkName, mode), env);
} catch (CompileException e) {
return varargsOf(Constants.NIL, valueOf(e.getMessage()));
}
});
}

private static class FunctionInputReader implements InputReader {
private static class FunctionInputReader extends InputReader {
private static final ByteBuffer EMPTY = ByteBuffer.allocate(0);

private final LuaState state;
Expand Down
17 changes: 13 additions & 4 deletions src/main/java/org/squiddev/cobalt/lib/system/SystemBaseLib.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import cc.tweaked.cobalt.internal.unwind.SuspendedAction;
import org.squiddev.cobalt.*;
import org.squiddev.cobalt.compiler.CompileException;
import org.squiddev.cobalt.compiler.LoadState;
import org.squiddev.cobalt.debug.DebugFrame;
import org.squiddev.cobalt.function.Dispatch;
import org.squiddev.cobalt.function.RegisteredFunction;
import org.squiddev.cobalt.lib.BaseLib;

import java.io.InputStream;
import java.io.PrintStream;
Expand Down Expand Up @@ -63,14 +64,14 @@ private static LuaValue collectgarbage(LuaState state, LuaValue arg1, LuaValue a
private Varargs loadfile(LuaState state, Varargs args) throws LuaError {
// loadfile( [filename] ) -> chunk | nil, msg
return args.first().isNil() ?
BaseLib.loadStream(state, in, STDIN_STR) :
SystemBaseLib.loadBasicStream(state, in, STDIN_STR) :
SystemBaseLib.loadFile(state, resources, args.arg(1).checkString());
}

private Varargs dofile(LuaState state, DebugFrame di, Varargs args) throws LuaError, UnwindThrowable {
// dofile( filename ) -> result1, ...
Varargs v = args.first().isNil() ?
BaseLib.loadStream(state, in, STDIN_STR) :
SystemBaseLib.loadBasicStream(state, in, STDIN_STR) :
SystemBaseLib.loadFile(state, resources, args.arg(1).checkString());
if (v.first().isNil()) {
throw new LuaError(v.arg(2).toString());
Expand Down Expand Up @@ -99,6 +100,14 @@ private Varargs print(LuaState state, DebugFrame frame, Varargs args) throws Lua
});
}

private static Varargs loadBasicStream(LuaState state, InputStream is, LuaString chunkName) {
try {
return LoadState.load(state, is, chunkName, state.globals());
} catch (LuaError | CompileException e) {
return varargsOf(Constants.NIL, valueOf(e.getMessage()));
}
}

/**
* Load from a named file, returning the chunk or nil,error of can't load
*
Expand All @@ -112,7 +121,7 @@ public static Varargs loadFile(LuaState state, ResourceLoader resources, String
return varargsOf(Constants.NIL, valueOf("cannot open " + filename + ": No such file or directory"));
}
try {
return BaseLib.loadStream(state, is, valueOf("@" + filename));
return loadBasicStream(state, is, valueOf("@" + filename));
} finally {
try {
is.close();
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/org/squiddev/cobalt/ProtectionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public void tearDown() {

@Timeout(3)
@ParameterizedTest(name = ParameterizedTest.ARGUMENTS_WITH_NAMES_PLACEHOLDER)
@ValueSource(strings = {"string", "loop"})
@ValueSource(strings = {"string", "loop", "load"})
public void run(String name) throws IOException, CompileException, LuaError, InterruptedException {
LuaThread.runMain(helpers.state, helpers.loadScript(name));
}
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/org/squiddev/cobalt/compiler/SimpleTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void setup() throws LuaError {
private void doTest(String script) {
try {
InputStream is = new ByteArrayInputStream(script.getBytes(StandardCharsets.UTF_8));
LuaFunction c = LoadState.interpretedFunction(LuaC.compile(state, is, valueOf("script"), null), _G);
LuaFunction c = LoadState.load(state, is, valueOf("script"), _G);
LuaThread.runMain(state, c);
} catch (Exception e) {
fail("i/o exception: " + e);
Expand Down Expand Up @@ -127,7 +127,7 @@ public void testZap() {
String s = "print('\\z";
assertThrows(CompileException.class, () -> {
InputStream is = new ByteArrayInputStream(s.getBytes(StandardCharsets.UTF_8));
LoadState.interpretedFunction(LuaC.compile(state, is, valueOf("script"), null), _G);
LoadState.load(state, is, valueOf("script"), _G);
});
}

Expand Down
13 changes: 13 additions & 0 deletions src/test/resources/protection/load.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
--- Test loading long strings

local function check(...)
local success, message = pcall(...)

assert(not success, "Expected abort")
assert(message:find("Timed out"), "Got " .. message)
end

check(function()
local fn, err = load("--[" .. ("="):rep(1e8) .. "[")
print(fn, err)
end)

0 comments on commit 97d2f36

Please sign in to comment.