From 72a20bdb67155bda53e7418a7f0d95e627d11a6c Mon Sep 17 00:00:00 2001 From: odersky Date: Sat, 10 Aug 2024 14:21:04 +0200 Subject: [PATCH] Improve Contains handling Make use of enclosing Contains assumptions to improve the subsumes logic. --- .../src/dotty/tools/dotc/cc/CaptureOps.scala | 18 ++++++++ .../src/dotty/tools/dotc/cc/CaptureRef.scala | 4 ++ .../src/dotty/tools/dotc/cc/CaptureSet.scala | 8 +++- .../dotty/tools/dotc/cc/CheckCaptures.scala | 46 +++++++++++-------- .../dotty/tools/dotc/core/Definitions.scala | 2 +- tests/pos-custom-args/captures/i21313.scala | 11 ++++- 6 files changed, 66 insertions(+), 23 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index a2d2d2cf358c..9b7d2b90ed1a 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -713,3 +713,21 @@ extension (self: Type) case _ => self +/** An extractor for a contains argument */ +object ContainsImpl: + def unapply(tree: TypeApply)(using Context): Option[(Tree, Tree)] = + tree.fun.tpe.widen match + case fntpe: PolyType if tree.fun.symbol == defn.Caps_containsImpl => + tree.args match + case csArg :: refArg :: Nil => Some((csArg, refArg)) + case _ => None + case _ => None + +/** An extractor for a contains parameter */ +object ContainsParam: + def unapply(sym: Symbol)(using Context): Option[(TypeRef, CaptureRef)] = + sym.info.dealias match + case AppliedType(tycon, (cs: TypeRef) :: (ref: CaptureRef) :: Nil) + if tycon.typeSymbol == defn.Caps_ContainsTrait + && cs.typeSymbol.isAbstractOrParamType => Some((cs, ref)) + case _ => None diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureRef.scala b/compiler/src/dotty/tools/dotc/cc/CaptureRef.scala index 6578da89bbf8..f00c6869cd80 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureRef.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureRef.scala @@ -116,8 +116,12 @@ trait CaptureRef extends TypeProxy, ValueType: case x1: SingletonCaptureRef => x1.subsumes(y) case _ => false case x: TermParamRef => subsumesExistentially(x, y) + case x: TypeRef => assumedContainsOf(x).contains(y) case _ => false + def assumedContainsOf(x: TypeRef)(using Context): SimpleIdentitySet[CaptureRef] = + CaptureSet.assumedContains.getOrElse(x, SimpleIdentitySet.empty) + end CaptureRef trait SingletonCaptureRef extends SingletonType, CaptureRef diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index 1d09b9dc5f20..25d8e0bc6506 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -16,7 +16,7 @@ import util.{SimpleIdentitySet, Property} import typer.ErrorReporting.Addenda import TypeComparer.subsumesExistentially import util.common.alwaysTrue -import scala.collection.mutable +import scala.collection.{mutable, immutable} import CCState.* /** A class for capture sets. Capture sets can be constants or variables. @@ -1125,6 +1125,12 @@ object CaptureSet: foldOver(cs, t) collect(CaptureSet.empty, tp) + type AssumedContains = immutable.Map[TypeRef, SimpleIdentitySet[CaptureRef]] + val AssumedContains: Property.Key[AssumedContains] = Property.Key() + + def assumedContains(using Context): AssumedContains = + ctx.property(AssumedContains).getOrElse(immutable.Map.empty) + private val ShownVars: Property.Key[mutable.Set[Var]] = Property.Key() /** Perform `op`. Under -Ycc-debug, collect and print info about all variables reachable diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index dbf01915122d..51cf362ca667 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -676,29 +676,24 @@ class CheckCaptures extends Recheck, SymTransformer: i"Sealed type variable $pname", "be instantiated to", i"This is often caused by a local capability$where\nleaking as part of its result.", tree.srcPos) - val res = handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt))) - if meth == defn.Caps_containsImpl then checkContains(tree) - res + try handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt))) + finally checkContains(tree) end recheckTypeApply /** Faced with a tree of form `caps.contansImpl[CS, r.type]`, check that `R` is a tracked * capability and assert that `{r} <:CS`. */ - def checkContains(tree: TypeApply)(using Context): Unit = - tree.fun.knownType.widen match - case fntpe: PolyType => - tree.args match - case csArg :: refArg :: Nil => - val cs = csArg.knownType.captureSet - val ref = refArg.knownType - capt.println(i"check contains $cs , $ref") - ref match - case ref: CaptureRef if ref.isTracked => - checkElem(ref, cs, tree.srcPos) - case _ => - report.error(em"$refArg is not a tracked capability", refArg.srcPos) - case _ => - case _ => + def checkContains(tree: TypeApply)(using Context): Unit = tree match + case ContainsImpl(csArg, refArg) => + val cs = csArg.knownType.captureSet + val ref = refArg.knownType + capt.println(i"check contains $cs , $ref") + ref match + case ref: CaptureRef if ref.isTracked => + checkElem(ref, cs, tree.srcPos) + case _ => + report.error(em"$refArg is not a tracked capability", refArg.srcPos) + case _ => override def recheckBlock(tree: Block, pt: Type)(using Context): Type = inNestedLevel(super.recheckBlock(tree, pt)) @@ -814,15 +809,26 @@ class CheckCaptures extends Recheck, SymTransformer: val localSet = capturedVars(sym) if !localSet.isAlwaysEmpty then curEnv = Env(sym, EnvKind.Regular, localSet, curEnv) + + // ctx with AssumedContains entries for each Contains parameter + val bodyCtx = + var ac = CaptureSet.assumedContains + for paramSyms <- sym.paramSymss do + for case ContainsParam(cs, ref) <- paramSyms do + ac = ac.updated(cs, ac.getOrElse(cs, SimpleIdentitySet.empty) + ref) + if ac.isEmpty then ctx + else ctx.withProperty(CaptureSet.AssumedContains, Some(ac)) + inNestedLevel: // TODO: needed here? - try checkInferredResult(super.recheckDefDef(tree, sym), tree) + try checkInferredResult(super.recheckDefDef(tree, sym)(using bodyCtx), tree) finally if !sym.isAnonymousFunction then // Anonymous functions propagate their type to the enclosing environment // so it is not in general sound to interpolate their types. interpolateVarsIn(tree.tpt) curEnv = saved - + end recheckDefDef + /** If val or def definition with inferred (result) type is visible * in other compilation units, check that the actual inferred type * conforms to the expected type where all inferred capture sets are dropped. diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 1d2f2b05feb4..8981aa4aa6ac 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1002,7 +1002,7 @@ class Definitions { @tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox") @tu lazy val Caps_unsafeUnbox: Symbol = CapsUnsafeModule.requiredMethod("unsafeUnbox") @tu lazy val Caps_unsafeBoxFunArg: Symbol = CapsUnsafeModule.requiredMethod("unsafeBoxFunArg") - @tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Capability") + @tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains") @tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl") @tu lazy val PureClass: Symbol = requiredClass("scala.Pure") diff --git a/tests/pos-custom-args/captures/i21313.scala b/tests/pos-custom-args/captures/i21313.scala index 2fda6c0c0e45..b388b6487cb5 100644 --- a/tests/pos-custom-args/captures/i21313.scala +++ b/tests/pos-custom-args/captures/i21313.scala @@ -1,7 +1,16 @@ import caps.CapSet trait Async: - def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T + def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T = + val x: Async^{this} = ??? + val y: Async^{Cap^} = x + val ac: Async^ = ??? + def f(using caps.Contains[Cap, ac.type]) = + val x2: Async^{this} = ??? + val y2: Async^{Cap^} = x2 + val x3: Async^{ac} = ??? + val y3: Async^{Cap^} = x3 + ??? trait Source[+T, Cap^]: final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap.