Skip to content

Commit

Permalink
Flatten captured variables in lambdas using the correct environment
Browse files Browse the repository at this point in the history
  • Loading branch information
liufengyun committed Sep 9, 2024
1 parent 5fd1d5d commit c52f418
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,11 @@ class Objects(using Context @constructorOnly):
* @param klass The enclosing class of the anonymous function's creation site
*/
case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data)(using @constructorOnly ctx: Context) extends ValueElement:
def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ")"
def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ", " + freeVals + ", " + env.show + ")"

val freeVals: Set[Symbol] = computeFreeVars()(using ctx)

def computeFreeVars()(using ctx: Context): Set[Symbol] =
def computeFreeVars()(using Context): Set[Symbol] =
// TODO: compute captures transitively for local methods
val refs = mutable.Set.empty[Symbol]
val defs = mutable.Set.empty[Symbol]
Expand All @@ -263,10 +263,18 @@ class Objects(using Context @constructorOnly):
traverser.traverse(code)
refs.diff(defs).toSet

def flatten: Iterable[Value | Addr] =
// Early compute the flattened value to avoid capturing `ctx`
val flatten = computeflatten()(using ctx)

def computeflatten()(using Context): Iterable[Value | Addr] =
val captured = freeVals.flatMap: x =>
val resOpt = Env.get(x)(using env)
resOpt.map(_ :: Nil).getOrElse(Nil)
Env.resolveEnv(x.enclosingMethod, thisV, env) match
case Some(thisV -> env) =>
val resOpt = Env.get(x)(using env)
resOpt.map(_ :: Nil).getOrElse(Nil)

case None =>
Nil

captured ++ Vector(thisV)

Expand Down Expand Up @@ -833,7 +841,7 @@ class Objects(using Context @constructorOnly):
* @param needResolve Whether the target of the call needs resolution?
*/
def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log(
"call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show) + ", heap.size = " + Heap.getHeapData().size, printer, (_: Value).show) {
"call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show) + ", heap = " + Heap.getHeapData().size, printer, (_: Value).show) {

value.filterClass(meth.owner) match
case Cold =>
Expand Down Expand Up @@ -1281,9 +1289,10 @@ class Objects(using Context @constructorOnly):
* @param klass The enclosing class where the expression is located.
* @param ctx The context where `eval` is called.
*/
def eval(expr: Tree, thisV: ThisValue, klass: ClassSymbol, ctx: EvalContext = EvalContext.Other): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + ", heap.size = " + Heap.getHeapData().size + " in " + klass.show, printer, (_: Value).show) {
cache.cachedEval(thisV, expr, ctx) { expr => cases(expr, thisV, klass) }
}
def eval(expr: Tree, thisV: ThisValue, klass: ClassSymbol, ctx: EvalContext = EvalContext.Other): Contextual[Value] =
log("evaluating " + expr.show + ", this = " + thisV.show + ", heap = " + Heap.getHeapData().size + " in " + klass.show, printer, (_: Value).show) {
cache.cachedEval(thisV, expr, ctx) { expr => cases(expr, thisV, klass) }
}


/** Evaluate a list of expressions */
Expand All @@ -1298,7 +1307,9 @@ class Objects(using Context @constructorOnly):
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
* @param klass The enclosing class where the expression `expr` is located.
*/
def cases(expr: Tree, thisV: ThisValue, klass: ClassSymbol): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) {
def cases(expr: Tree, thisV: ThisValue, klass: ClassSymbol): Contextual[Value] = log(
"evaluating " + expr.show + ", this = " + thisV.show + ", heap = " + Heap.getHeapData().size + " in " + klass.show, printer, (_: Value).show) {

val trace2 = trace.add(expr)

expr match
Expand Down

0 comments on commit c52f418

Please sign in to comment.