From c52f418d4dcbfae0f0d7829cde74b34f0e73ddff Mon Sep 17 00:00:00 2001 From: Fengyun Liu Date: Mon, 9 Sep 2024 23:32:18 +0200 Subject: [PATCH] Flatten captured variables in lambdas using the correct environment --- .../tools/dotc/transform/init/Objects.scala | 31 +++++++++++++------ 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala index 55ea10f51105..da976c143db8 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala @@ -236,11 +236,11 @@ class Objects(using Context @constructorOnly): * @param klass The enclosing class of the anonymous function's creation site */ case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data)(using @constructorOnly ctx: Context) extends ValueElement: - def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ")" + def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ", " + freeVals + ", " + env.show + ")" val freeVals: Set[Symbol] = computeFreeVars()(using ctx) - def computeFreeVars()(using ctx: Context): Set[Symbol] = + def computeFreeVars()(using Context): Set[Symbol] = // TODO: compute captures transitively for local methods val refs = mutable.Set.empty[Symbol] val defs = mutable.Set.empty[Symbol] @@ -263,10 +263,18 @@ class Objects(using Context @constructorOnly): traverser.traverse(code) refs.diff(defs).toSet - def flatten: Iterable[Value | Addr] = + // Early compute the flattened value to avoid capturing `ctx` + val flatten = computeflatten()(using ctx) + + def computeflatten()(using Context): Iterable[Value | Addr] = val captured = freeVals.flatMap: x => - val resOpt = Env.get(x)(using env) - resOpt.map(_ :: Nil).getOrElse(Nil) + Env.resolveEnv(x.enclosingMethod, thisV, env) match + case Some(thisV -> env) => + val resOpt = Env.get(x)(using env) + resOpt.map(_ :: Nil).getOrElse(Nil) + + case None => + Nil captured ++ Vector(thisV) @@ -833,7 +841,7 @@ class Objects(using Context @constructorOnly): * @param needResolve Whether the target of the call needs resolution? */ def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log( - "call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show) + ", heap.size = " + Heap.getHeapData().size, printer, (_: Value).show) { + "call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show) + ", heap = " + Heap.getHeapData().size, printer, (_: Value).show) { value.filterClass(meth.owner) match case Cold => @@ -1281,9 +1289,10 @@ class Objects(using Context @constructorOnly): * @param klass The enclosing class where the expression is located. * @param ctx The context where `eval` is called. */ - def eval(expr: Tree, thisV: ThisValue, klass: ClassSymbol, ctx: EvalContext = EvalContext.Other): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + ", heap.size = " + Heap.getHeapData().size + " in " + klass.show, printer, (_: Value).show) { - cache.cachedEval(thisV, expr, ctx) { expr => cases(expr, thisV, klass) } - } + def eval(expr: Tree, thisV: ThisValue, klass: ClassSymbol, ctx: EvalContext = EvalContext.Other): Contextual[Value] = + log("evaluating " + expr.show + ", this = " + thisV.show + ", heap = " + Heap.getHeapData().size + " in " + klass.show, printer, (_: Value).show) { + cache.cachedEval(thisV, expr, ctx) { expr => cases(expr, thisV, klass) } + } /** Evaluate a list of expressions */ @@ -1298,7 +1307,9 @@ class Objects(using Context @constructorOnly): * @param thisV The value for `C.this` where `C` is represented by the parameter `klass`. * @param klass The enclosing class where the expression `expr` is located. */ - def cases(expr: Tree, thisV: ThisValue, klass: ClassSymbol): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) { + def cases(expr: Tree, thisV: ThisValue, klass: ClassSymbol): Contextual[Value] = log( + "evaluating " + expr.show + ", this = " + thisV.show + ", heap = " + Heap.getHeapData().size + " in " + klass.show, printer, (_: Value).show) { + val trace2 = trace.add(expr) expr match