From 8142acd1c11856a2a6d41436838d7d83ce0c1d9b Mon Sep 17 00:00:00 2001 From: aaron <83140718+centau@users.noreply.github.com> Date: Sun, 6 Oct 2024 15:20:04 +0100 Subject: [PATCH] Add `context()` --- CHANGELOG.md | 4 ++ docs/api/reactivity-core.md | 6 +- docs/api/reactivity-utility.md | 41 +++++++++++++ src/context.luau | 75 ++++++++++++++++++++++++ src/graph.luau | 13 +++++ src/init.luau | 5 ++ test/benchmark.luau | 58 +++++++++++++++++++ test/tests.luau | 102 +++++++++++++++++++++++++++++++++ 8 files changed, 303 insertions(+), 1 deletion(-) create mode 100644 src/context.luau diff --git a/CHANGELOG.md b/CHANGELOG.md index 89cf8e7..5ba48b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## Unreleased +### Added + +- `context()` + ### Fixed - Error stack traces being lost. diff --git a/docs/api/reactivity-core.md b/docs/api/reactivity-core.md index 8b860d0..93c2263 100644 --- a/docs/api/reactivity-core.md +++ b/docs/api/reactivity-core.md @@ -34,7 +34,11 @@ Creates a new source with the given value. - **Type** ```lua - function source(value: T): (T?) -> T + function source(value: T): Source + + type Source = + () -> T -- get + & (T) -> () -- set ``` - **Details** diff --git a/docs/api/reactivity-utility.md b/docs/api/reactivity-utility.md index 7f08657..824fe35 100644 --- a/docs/api/reactivity-utility.md +++ b/docs/api/reactivity-utility.md @@ -90,4 +90,45 @@ trigger effects until after the function finishes running. only cause the effect to run once after the batch call ends instead of after each time a source is updated. +## context() + +Creates a new context. + +- **Type** + + ```lua + function context(default: T): Context + + type Context = + () -> T -- get + & (T, () -> ()) -> () -- set + ``` + +- **Details** + + Calling `context()` returns a new context function. + Call this function with no arguments to get the context value. + Call this function with a value and a callback to set a new context with the + given value. + +- **Example** + + ```lua + local theme = context() + + local function Button() + print(theme()) + end + + root(function() + theme("light", function() + Button() -- prints "light" + + theme("dark", function() + Button() -- prints "dark" + end) + end) + end) + ``` + -------------------------------------------------------------------------------- diff --git a/src/context.luau b/src/context.luau new file mode 100644 index 0000000..fc0e8c1 --- /dev/null +++ b/src/context.luau @@ -0,0 +1,75 @@ +if not game then script = require "test/relative-string" end + +local throw = require(script.Parent.throw) +local graph = require(script.Parent.graph) +type Node = graph.Node +local create_node = graph.create_node +local get_scope = graph.get_scope +local push_scope = graph.push_scope +local pop_scope = graph.pop_scope +local set_context = graph.set_context + +export type Context = (() -> T) & ((T, () -> ()) -> ()) + +local nil_symbol = newproxy() +local count = 0 + +local function context(...: T): Context + count += 1 + local id = count + + local has_default = select("#", ...) > 0 + local default_value = ... + + return function(...) + local scope: Node? | false = get_scope() + + if select("#", ...) == 0 then -- get + while scope do + local ctx = scope.context + + if not ctx then + scope = scope.owner + continue + end + + local value = (ctx :: { unknown })[id] + + if value == nil then + scope = scope.owner + continue + end + + return (if value ~= nil_symbol then value else nil) :: T + end + + if has_default ~= nil then + return default_value + else + throw("attempt to get context when no context is set and no default context is set") + end + else -- set + if not scope then return throw("attempt to set context outside of a vide scope") end + + local value, component = ... + + local new_scope = create_node(scope, false, false) + set_context(new_scope, id, if value == nil then nil_symbol else value) + + push_scope(new_scope) + + local function efn(err: string) return debug.traceback(err, 3) end + local ok, result = xpcall(component, efn) + + pop_scope() + + if not ok then + throw(`error while running context:\n\n{result}`) + end + end + + return nil :: any + end +end + +return context diff --git a/src/graph.luau b/src/graph.luau index 34deff4..9648180 100644 --- a/src/graph.luau +++ b/src/graph.luau @@ -13,6 +13,8 @@ export type Node = { effect: ((T) -> T) | false, cleanups: { () -> () } | false, + context: { [number]: unknown } | false, + owned: { Node } | false, owner: Node | false, @@ -241,6 +243,8 @@ local function create_node(owner: false | Node, effect: false | (T) -> T effect = effect, cleanups = false, + context = false, + owner = owner, owned = false, @@ -266,6 +270,14 @@ local function get_children(node: Node): { Node } return { unpack(node) } :: { Node } end +local function set_context(node: Node, key: number, value: unknown) + if node.context then + node.context[key] = value + else + node.context = { [key] = value } + end +end + return table.freeze { push_scope = push_scope, pop_scope = pop_scope, @@ -283,5 +295,6 @@ return table.freeze { get_children = get_children, flush_update_queue = flush_update_queue, get_update_queue_length = get_update_queue_length, + set_context = set_context, scopes = scopes } diff --git a/src/init.luau b/src/init.luau index 840f7f9..64a4a55 100644 --- a/src/init.luau +++ b/src/init.luau @@ -16,6 +16,7 @@ local cleanup = require(script.cleanup) local untrack = require(script.untrack) local read = require(script.read) local batch = require(script.batch) +local context = require(script.context) local switch = require(script.switch) local show = require(script.show) local indexes, values = require(script.maps)() @@ -26,6 +27,9 @@ local throw = require(script.throw) local flags = require(script.flags) export type Source = source.Source +export type source = Source +export type Context = context.Context +export type context = Context local function step(dt: number) if game then @@ -63,6 +67,7 @@ local vide = { untrack = untrack, read = read, batch = batch, + context = context, -- animations spring = spring, diff --git a/test/benchmark.luau b/test/benchmark.luau index c35b651..05fe3ef 100644 --- a/test/benchmark.luau +++ b/test/benchmark.luau @@ -4,11 +4,14 @@ local BENCH, START = testkit.benchmark() local vide = require "src/init" local source = vide.source local derive = vide.derive +local effect = vide.effect local indexes = vide.indexes local values = vide.values local batch = vide.batch local cleanup = vide.cleanup +local untrack = vide.untrack local create = vide.create +local context = vide.context assert(not vide.strict) @@ -446,6 +449,61 @@ ROOT_BENCH("values() all remove", function() src(data) end) +TITLE "context()" + +ROOT_BENCH("set context", function() + local ctx = context() + + for i = 1, START(N) do + ctx(i, function() end) + end +end) + +ROOT_BENCH("get context (depth=1)", function() + local ctx = context() + + local function run() + for i = 1, START(N) do + ctx() + end + end + + ctx(1, function() + run() + end) +end) + +local depth = 10 +ROOT_BENCH(`get context (depth={depth})`, function() + + local ctx = context() + + local function run() + for i = 1, START(N) do + ctx() + end + end + + local function nest_effect(fn) + untrack(function() + effect(fn) + return nil + end) + end + + local f = run + for i = 1, depth - 1 do + local f_inner = f + f = function() + nest_effect(f_inner) + end + end + + ctx(1, function() + f() + end) +end) + N *= 1024 TITLE "aggregate" diff --git a/test/tests.luau b/test/tests.luau index fe2778d..55642ff 100644 --- a/test/tests.luau +++ b/test/tests.luau @@ -2155,6 +2155,108 @@ TEST("read()", wrap_root(function() end end)) +TEST("context()", function() + local root = vide.root + local context = vide.context + local effect = vide.effect + local untrack = vide.untrack + local show = vide.show + + do CASE "set context" + local ctx = context() + + root(function() + ctx(1, function() + CHECK(ctx() == 1) + + effect(function() + CHECK(ctx() == 1) + end) + end) + end) + end + + do CASE "set context outside of scope" + local ctx = context() + + local ok = pcall(function() + ctx(1, function() end) + end) + + CHECK(not ok) + end + + do CASE "get default context" + local ctx = context(1) + + CHECK(ctx() == 1) + + root(function() + ctx(2, function() + CHECK(ctx() == 2) + end) + + CHECK(ctx() == 1) + end) + end + + do CASE "context cascade" + local ctx = context(1) + local ctx2 = context() + + root(function() + ctx(2, function() + ctx2(true, function() + show(function() return true end, function() + ctx(3, function() + effect(function() + CHECK(ctx() == 3) + untrack(function() + effect(function() + CHECK(ctx() == 3) + end) + CHECK(ctx2() == true) + return {} + end) + end) + CHECK(ctx() == 3) + end) + CHECK(ctx() == 2) + return {} + end) + end) + CHECK(ctx() == 2) + end) + + CHECK(ctx() == 1) + end) + end + + do CASE "nil context" + local ctx = context(nil) + + root(function() + CHECK(ctx() == nil) + ctx(nil, function() + CHECK(ctx() == nil) + ctx(true :: any, function() + CHECK(ctx() == true) + ctx(nil, function() + effect(function() + untrack(function() + effect(function() + CHECK(ctx() == nil) + end) + return {} + end) + end) + end) + end) + end) + end) + end +end) + TEST("nested effects cases", function() local vide = require "src/init" local source = vide.source