Skip to content

Commit

Permalink
feat: support function script flags
Browse files Browse the repository at this point in the history
  • Loading branch information
PokIsemaine committed Jul 25, 2024
1 parent e1fefbc commit 8f290f2
Show file tree
Hide file tree
Showing 7 changed files with 376 additions and 153 deletions.
2 changes: 0 additions & 2 deletions src/cluster/cluster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,6 @@ bool Cluster::IsWriteForbiddenSlot(int slot) const {
Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, const std::vector<std::string> &cmd_tokens,
redis::Connection *conn, lua::ScriptRunCtx *script_run_ctx) {
std::vector<int> keys_indexes;
std::cout << "CanExecByMySelf\n";
// No keys
if (auto s = redis::CommandTable::GetKeysFromCommand(attributes, cmd_tokens, &keys_indexes); !s.IsOK())
return Status::OK();
Expand All @@ -851,7 +850,6 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons

bool cross_slot_ok = false;
if (script_run_ctx) {
std::cout << "Check script_run_ctx\n";
if (script_run_ctx->current_slot != -1 && script_run_ctx->current_slot != slot) {
if (getNodeIDBySlot(script_run_ctx->current_slot) != getNodeIDBySlot(slot)) {
return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)};
Expand Down
11 changes: 11 additions & 0 deletions src/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1682,6 +1682,16 @@ Status Server::ScriptSet(const std::string &sha, const std::string &body) const
return storage->WriteToPropagateCF(func_name, body);
}

void Server::CacheScriptFlags(const std::string &sha, uint64_t flags) { script_flags_cache_.try_emplace(sha, flags); }

[[nodiscard]] Status Server::GetScriptFlags(const std::string &sha, uint64_t &flags) const {
if (script_flags_cache_.count(sha)) {
flags = script_flags_cache_.at(sha);
return Status::OK();
}
return {Status::NotFound, "The flags cache of script sha does not exist, sha: " + sha};
}

Status Server::FunctionGetCode(const std::string &lib, std::string *code) const {
std::string func_name = engine::kLuaLibCodePrefix + lib;
auto cf = storage->GetCFHandle(ColumnFamilyID::Propagate);
Expand Down Expand Up @@ -1714,6 +1724,7 @@ Status Server::FunctionSetLib(const std::string &func, const std::string &lib) c

void Server::ScriptReset() {
auto lua = lua_.exchange(lua::CreateState(this));
script_flags_cache_.clear();
lua::DestroyState(lua);
}

Expand Down
4 changes: 4 additions & 0 deletions src/server/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ class Server {
Status ScriptSet(const std::string &sha, const std::string &body) const;
void ScriptReset();
Status ScriptFlush();
void CacheScriptFlags(const std::string &sha, uint64_t flags);
[[nodiscard]] Status GetScriptFlags(const std::string &sha, uint64_t &flags) const;

Status FunctionGetCode(const std::string &lib, std::string *code) const;
Status FunctionGetLib(const std::string &func, std::string *lib) const;
Expand Down Expand Up @@ -341,6 +343,8 @@ class Server {
std::mutex last_random_key_cursor_mu_;

std::atomic<lua_State *> lua_;
// The cache of flag is cached when the script is created and cleared when the script is flushed.
std::unordered_map<std::string, uint64_t> script_flags_cache_;

redis::Connection *curr_connection_ = nullptr;

Expand Down
171 changes: 118 additions & 53 deletions src/storage/scripting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,6 @@ void LoadFuncs(lua_State *lua) {
lua_pcall(lua, 0, 0, 0);
}

void LoadScriptFlags(lua_State *lua, uint64_t flags) {
std::cout << "LoadScriptFlags:" << flags << '\n';
lua_getglobal(lua, "redis");
lua_pushstring(lua, "script_flags");
lua_pushinteger(lua, static_cast<lua_Integer>(flags));
lua_settable(lua, -3);
lua_pop(lua, 1);
stackDump(lua);
}
int RedisLogCommand(lua_State *lua) {
int argc = lua_gettop(lua);

Expand Down Expand Up @@ -247,6 +238,9 @@ int RedisRegisterFunction(lua_State *lua) {

// set this function to global
std::string name = lua_tostring(lua, 1);
if (argc == 3) {
lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str());
}
lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str());

// set this function name to REDIS_FUNCTION_LIBRARIES[libname]
Expand Down Expand Up @@ -278,7 +272,6 @@ int RedisRegisterFunction(lua_State *lua) {
lua_pushstring(lua, "redis.register_function() failed to store informantion.");
return lua_error(lua);
}

return 0;
}

Expand All @@ -292,9 +285,8 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee
return {Status::NotOK, "Expect a Shebang statement in the first line"};
}

ShebangParser parser(first_line);
if (auto s = parser.Parse(); !s.IsOK()) return s;
auto libname = parser.GetLibName();
std::string libname;
if (auto s = ExtractLibNameFromShebang(first_line, libname); !s.IsOK()) return s;

auto srv = conn->GetServer();
auto lua = read_only ? conn->Owner()->Lua() : srv->Lua();
Expand Down Expand Up @@ -392,13 +384,43 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std:
std::string libcode;
s = srv->FunctionGetCode(libname, &libcode);
if (!s) return s;

s = FunctionLoad(conn, libcode, false, false, &libname, read_only);
if (!s) return s;

lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str());
}

uint64_t function_flags = read_only ? ScriptFlags::kScriptNoWrites : 0;
lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str());
// script, myfunc, err_func, table
if (!lua_isnil(lua, -1)) {
int n = static_cast<int>(lua_objlen(lua, -1));
for (int i = 1; i <= n; ++i) {
lua_pushnumber(lua, i);
lua_gettable(lua, -2);
std::string flag = lua_tostring(lua, -1);
if (flag == "no-writes") {
function_flags |= kScriptNoWrites;
} else if (flag == "allow-oom") {
return {Status::NotSupported, "allow-oom is not supported yet"};
} else if (flag == "allow-stale") {
return {Status::NotSupported, "allow-stale is not supported yet"};
} else if (flag == "no-cluster") {
function_flags |= kScriptNoCluster;
} else if (flag == "allow-cross-slot-keys") {
function_flags |= kScriptAllowCrossSlotKeys;
} else {
return {Status::NotOK, "Unexpected function flag: " + flag};
}
lua_pop(lua, 1);
}
}
lua_pop(lua, 1);

ScriptRunCtx script_run_ctx;
script_run_ctx.flags = function_flags;
SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx);

PushArray(lua, keys);
PushArray(lua, argv);
if (lua_pcall(lua, 2, 1, -4)) {
Expand Down Expand Up @@ -555,6 +577,8 @@ Status FunctionDelete(Server *srv, const std::string &name) {
std::string func = lua_tostring(lua, -1);
lua_pushnil(lua);
lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + func).c_str());
lua_pushnil(lua);
lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + func).c_str());
auto _ = storage->Delete(rocksdb::WriteOptions(), cf, engine::kLuaFuncLibPrefix + func);
lua_pop(lua, 1);
}
Expand Down Expand Up @@ -594,8 +618,6 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh

/* Try to lookup the Lua function */
lua_getglobal(lua, funcname);
std::cout << "Try to lookup the Lua function\n";
stackDump(lua);
if (lua_isnil(lua, -1)) {
lua_pop(lua, 1); /* remove the nil from the stack */
std::string body;
Expand All @@ -608,31 +630,6 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh
} else {
body = body_or_sha;
}
std::cout << "Get Body:\n" << body;
uint64_t script_flags = read_only ? ScriptFlags::kScriptNoWrites : 0;
if (auto pos = body.find('\n'); pos != std::string::npos) {
auto first_line = body.substr(0, pos);
std::cout << "\nGet First Line:" << first_line << '\n';

if (util::HasPrefix(first_line, "#!lua")) {
ShebangParser parser(first_line);
auto s = parser.Parse();
if (!s.IsOK()) {
lua_pop(lua, 1); /* remove the error handler from the stack. */
return s;
}
script_flags |= parser.GetFlags();
} else {
// scripts without #! can run commands that access keys belonging to different cluster hash slots,
// but ones with #! inherit the default flags, so they cannot.
script_flags |= ScriptFlags::kScriptAllowCrossSlotKeys;
}
}

ScriptRunCtx script_run_ctx;
script_run_ctx.flags = script_flags;
SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx);
// LoadScriptFlags(lua, script_flags);

std::string sha = funcname + 2;
auto s = CreateFunction(srv, body, &sha, lua, false);
Expand All @@ -644,6 +641,12 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh
lua_getglobal(lua, funcname);
}

ScriptRunCtx current_script_run_ctx;
auto s = srv->GetScriptFlags(funcname + 2, current_script_run_ctx.flags);
if (!s.IsOK()) return s;
if (read_only) current_script_run_ctx.flags |= ScriptFlags::kScriptNoWrites;
SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &current_script_run_ctx);

// For the Lua script, should be always run with RESP2 protocol,
// unless users explicitly set the protocol version in the script via `redis.setresp`.
// So we need to save the current protocol version and set it to RESP2,
Expand All @@ -654,12 +657,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh
* EVAL received. */
SetGlobalArray(lua, "KEYS", keys);
SetGlobalArray(lua, "ARGV", argv);
// int errfunc_index =
std::cout << "Before EvalGenericCommand lua_pcall\n";
stackDump(lua);
if (lua_pcall(lua, 0, 1, -2)) {
std::cout << "After EvalGenericCommand lua_pcall\n";
stackDump(lua);
auto msg = fmt::format("running script (call to {}): {}", funcname, lua_tostring(lua, -1));
*output = redis::Error({Status::NotOK, msg});
lua_pop(lua, 2);
Expand All @@ -674,6 +672,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh
lua_setglobal(lua, "KEYS");
lua_pushnil(lua);
lua_setglobal(lua, "ARGV");
SaveOnRegistry<void>(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, nullptr);

/* Call the Lua garbage collector from time to time to avoid a
* full cycle performed by Lua, which adds too latency.
Expand Down Expand Up @@ -714,9 +713,9 @@ Server *GetServer(lua_State *lua) {
// TODO: we do not want to repeat same logic as Connection::ExecuteCommands,
// so the function need to be refactored
int RedisGenericCommand(lua_State *lua, int raise_error) {
ScriptRunCtx *script_run_ctx = GetFromRegistry<ScriptRunCtx>(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
std::cout << "get script_flags = " << script_run_ctx->flags << '\n';
stackDump(lua);
auto *script_run_ctx = GetFromRegistry<ScriptRunCtx>(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
CHECK_NOTNULL(script_run_ctx);

int argc = lua_gettop(lua);
if (argc == 0) {
PushError(lua, "Please specify at least one argument for redis.call()");
Expand Down Expand Up @@ -777,8 +776,11 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {
}
auto s = srv->cluster->CanExecByMySelf(attributes, args, conn, script_run_ctx);
if (!s.IsOK()) {
std::cout << "CanExecByMySelf failed, s = " << s.Msg() << '\n';
PushError(lua, redis::StatusToRedisErrorMsg(s).c_str());
if (s.Is<Status::RedisMoved>()) {
PushError(lua, "Script attempted to access a non local key in a cluster node script");
} else {
PushError(lua, redis::StatusToRedisErrorMsg(s).c_str());
}
return raise_error ? RaiseError(lua) : 1;
}
}
Expand Down Expand Up @@ -1356,8 +1358,6 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) {
[[noreturn]] int RaiseError(lua_State *lua) {
lua_pushstring(lua, "err");
lua_gettable(lua, -2);
std::cout << "RaiseError\n";
stackDump(lua);
lua_error(lua);
__builtin_unreachable();
}
Expand Down Expand Up @@ -1488,8 +1488,73 @@ Status CreateFunction(Server *srv, const std::string &body, std::string *sha, lu
}
lua_setglobal(lua, funcname);

// Cache the flags of the current script
uint64_t script_flags = 0;
if (auto pos = body.find('\n'); pos != std::string::npos) {
auto first_line = body.substr(0, pos);
if (util::HasPrefix(first_line, "#!lua")) {
uint64_t shebang_flags = 0;
if (auto s = ExtractFlagsFromShebang(first_line, shebang_flags); !s.IsOK()) {
lua_pop(lua, 1); /* remove the error handler from the stack. */
return s;
}
script_flags |= shebang_flags;
} else {
// scripts without #! can run commands that access keys belonging to different cluster hash slots,
// but ones with #! inherit the default flags, so they cannot.
script_flags |= ScriptFlags::kScriptAllowCrossSlotKeys;
}
}
srv->CacheScriptFlags(*sha, script_flags);

// would store lua function into propagate column family and propagate those scripts to slaves
return need_to_store ? srv->ScriptSet(*sha, body) : Status::OK();
}

[[nodiscard]] Status ExtractLibNameFromShebang(const std::string &shebang, std::string &libname) {
static constexpr const char *shebang_prefix = "#!lua";
static constexpr const char *shebang_libname_prefix = "name=";

if (!util::HasPrefix(shebang, shebang_prefix)) {
return {Status::NotOK, "Expect shebang prefix \"#!lua\" at the beginning of the first line"};
}

if (auto pos = shebang.find(shebang_libname_prefix, strlen(shebang_prefix)); pos != std::string::npos) {
libname = shebang.substr(pos + strlen(shebang_libname_prefix));
if (libname.empty() ||
std::any_of(libname.begin(), libname.end(), [](char v) { return !std::isalnum(v) && v != '_'; })) {
return {Status::NotOK, "Expect a valid library name in the Shebang statement"};
}
return Status::OK();
}

return {Status::NotOK, "Expect a library name in the Shebang statement"};
}

[[nodiscard]] Status ExtractFlagsFromShebang(const std::string &shebang, uint64_t &flags) {
static constexpr const char *shebang_prefix = "#!lua";
static constexpr const char *shebang_flags_prefix = "flags=";

if (auto pos = shebang.find(shebang_flags_prefix, strlen(shebang_prefix)); pos != std::string::npos) {
auto flags_content = shebang.substr(pos + strlen(shebang_flags_prefix));
flags = 0;
for (const auto &flag : util::Split(flags_content, ",")) {
if (flag == "no-writes") {
flags |= kScriptNoWrites;
} else if (flag == "allow-oom") {
return {Status::NotSupported, "allow-oom is not supported yet"};
} else if (flag == "allow-stale") {
return {Status::NotSupported, "allow-stale is not supported yet"};
} else if (flag == "no-cluster") {
flags |= kScriptNoCluster;
} else if (flag == "allow-cross-slot-keys") {
flags |= kScriptAllowCrossSlotKeys;
} else {
return {Status::NotOK, "Unexpected flag in script shebang: " + flag};
}
}
}
return Status::OK();
}

} // namespace lua
Loading

0 comments on commit 8f290f2

Please sign in to comment.