From baaf47eb6ba990480cd963e5c6bf729678729fa6 Mon Sep 17 00:00:00 2001 From: Brent Yorgey Date: Fri, 24 May 2024 17:05:51 -0500 Subject: [PATCH] limit the number of intermediate solutions generated in the constraint solver to avoid exponential blowup --- src/Disco/Interactive/Commands.hs | 5 +- src/Disco/Typecheck.hs | 46 ++++++++++++----- src/Disco/Typecheck/Solve.hs | 82 +++++++++++++++++++++++-------- src/Disco/Typecheck/Util.hs | 11 +++-- 4 files changed, 104 insertions(+), 40 deletions(-) diff --git a/src/Disco/Interactive/Commands.hs b/src/Disco/Interactive/Commands.hs index 900f399f..e9c143bd 100644 --- a/src/Disco/Interactive/Commands.hs +++ b/src/Disco/Interactive/Commands.hs @@ -1029,12 +1029,15 @@ typeCheckCmd = , parser = TypeCheck <$> parseTermOrOp } +maxInferredTypes :: Int +maxInferredTypes = 16 + handleTypeCheck :: Members '[Error DiscoError, Input TopInfo, LFresh, Output (Message ())] r => REPLExpr 'CTypeCheck -> Sem r () handleTypeCheck (TypeCheck t) = do - asigs <- typecheckTop $ inferTop t + asigs <- typecheckTop $ inferTop maxInferredTypes t sigs <- runFresh . mapInput (view (replModInfo . miTydefs)) $ thin $ NE.map snd asigs let (toShow, extra) = NE.splitAt 8 sigs when (length sigs > 1) $ info "This expression has multiple possible types. Some examples:" diff --git a/src/Disco/Typecheck.hs b/src/Disco/Typecheck.hs index 19f99e4a..84a35249 100644 --- a/src/Disco/Typecheck.hs +++ b/src/Disco/Typecheck.hs @@ -41,7 +41,7 @@ import qualified Disco.Subst as Subst import Disco.Syntax.Operators import Disco.Syntax.Prims import Disco.Typecheck.Constraints -import Disco.Typecheck.Solve (solveConstraint) +import Disco.Typecheck.Solve (SolutionLimit (..), solveConstraint) import Disco.Typecheck.Util import Disco.Types import Disco.Types.Rules @@ -50,6 +50,7 @@ import Polysemy.Error import Polysemy.Input import Polysemy.Output import Polysemy.Reader +import Polysemy.State (evalState) import Polysemy.Writer import Text.EditDistance (defaultEditCosts, restrictedDamerauLevenshteinDistance) import Unbound.Generics.LocallyNameless ( @@ -281,6 +282,8 @@ checkDefn :: TermDefn -> Sem r Defn checkDefn name (TermDefn x clauses) = mapError (LocTCError (Just (name .- x))) $ do + debug "======================================================================" + debug "Checking definition:" -- Check that all clauses have the same number of patterns checkNumPats clauses @@ -295,7 +298,7 @@ checkDefn name (TermDefn x clauses) = mapError (LocTCError (Just (name .- x))) $ -- patterns, and lazily unrolling type definitions along the way. (patTys, bodyTy) <- decomposeDefnTy (numPats (NE.head clauses)) ty - ((acs, _), thetas) <- solve $ do + ((acs, _), thetas) <- solve 1 $ do aclauses <- forAll nms $ mapM (checkClause patTys bodyTy) clauses return (aclauses, ty) @@ -311,7 +314,7 @@ checkDefn name (TermDefn x clauses) = mapError (LocTCError (Just (name .- x))) $ -- patterns don't match across different clauses | otherwise = return () - -- \| Check a clause of a definition against a list of pattern types and a body type. + -- Check a clause of a definition against a list of pattern types and a body type. checkClause :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => [Type] -> @@ -359,7 +362,10 @@ checkProperty :: Property -> Sem r AProperty checkProperty prop = do - (at, thetas) <- solve $ check prop TyProp + debug "======================================================================" + debug "Checking property:" + debugPretty prop + (at, thetas) <- solve 1 $ check prop TyProp -- XXX do we need to default container variables here? return $ applySubst (NE.head thetas) at @@ -454,24 +460,30 @@ inferTop' :: Members '[Output (Message ann), Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh] r => Term -> Sem r (ATerm, PolyType) -inferTop' = fmap NE.head . inferTop +inferTop' t = NE.head <$> inferTop 1 t --- | Top-level type inference algorithm: infer some possible --- (polymorphic) types for a term by running type inference, solving --- the resulting constraints, and quantifying over any remaining --- type variables. +-- | Top-level type inference algorithm: infer up to the requested max +-- number of possible (polymorphic) types for a term by running type +-- inference, solving the resulting constraints, and quantifying +-- over any remaining type variables. inferTop :: Members '[Output (Message ann), Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh] r => + Int -> Term -> Sem r (NonEmpty (ATerm, PolyType)) -inferTop t = do +inferTop lim t = do -- Run inference on the term and try to solve the resulting -- constraints. - (at, thetas) <- solve $ infer t + debug "======================================================================" + debug "Inferring the type of:" + debugPretty t + (at, thetas) <- solve lim $ infer t debug "Final annotated term (before substitution and container monomorphizing):" debugPretty at + -- XXX include this container variable stuff in solution limits...? + -- Quantify over any remaining type variables and return -- the term along with the resulting polymorphic type. return $ do @@ -504,7 +516,11 @@ checkTop :: PolyType -> Sem r ATerm checkTop t ty = do - (at, theta) <- solve $ checkPolyTy t ty + debug "======================================================================" + debug "Checking the type of:" + debugPretty t + debugPretty ty + (at, theta) <- solve 1 $ checkPolyTy t ty return $ applySubst (NE.head theta) at -------------------------------------------------- @@ -1746,7 +1762,11 @@ isSubPolyType (Forall b1) (Forall b2) = do (as1, ty1) <- unbind b1 (as2, ty2) <- unbind b2 let c = CAll (bind as1 (CSub ty1 (substs (zip as2 (map TyVar as1)) ty2))) - ss <- runError (solveConstraint c) + debug "======================================================================" + debug "Checking subtyping..." + debugPretty (Forall b1) + debugPretty (Forall b2) + ss <- runError (evalState (SolutionLimit 1) (solveConstraint c)) return (either (const False) (not . P.null) ss) thin :: Members '[Input TyDefCtx, Output (Message ann), Fresh] r => NonEmpty PolyType -> Sem r (NonEmpty PolyType) diff --git a/src/Disco/Typecheck/Solve.hs b/src/Disco/Typecheck/Solve.hs index 3e792a0b..aecafc34 100644 --- a/src/Disco/Typecheck/Solve.hs +++ b/src/Disco/Typecheck/Solve.hs @@ -26,7 +26,7 @@ import Unbound.Generics.LocallyNameless ( import Control.Arrow ((&&&), (***)) import Control.Lens hiding (use, (%=), (.=)) -import Control.Monad (forM, unless, zipWithM) +import Control.Monad (forM, join, unless, zipWithM) import Data.Bifunctor (first, second) import Data.Coerce import Data.Either (partitionEithers) @@ -89,8 +89,8 @@ instance Semigroup SolveError where -------------------------------------------------- -- Error utilities -runSolve :: Sem (Fresh ': Error SolveError ': r) a -> Sem r (Either SolveError a) -runSolve = runError . runFresh +runSolve :: SolutionLimit -> Sem (State SolutionLimit ': Fresh ': Error SolveError ': r) a -> Sem r (Either SolveError a) +runSolve lim = runError . runFresh . evalState lim -- | Run a list of actions, and return the results from those which do -- not throw an error. If all of them throw an error, rethrow the @@ -102,6 +102,32 @@ filterErrors ms = do Left (e :| _) -> throw e Right (_, as) -> return as +-------------------------------------------------- +-- Solution limits + +-- | Max number of solutions to generate. +newtype SolutionLimit = SolutionLimit {getSolutionLimit :: Int} + +-- | Register the fact that we found one solution, by decrementing the +-- solution limit. +countSolution :: Member (State SolutionLimit) r => Sem r () +countSolution = modify (SolutionLimit . subtract 1 . getSolutionLimit) + +-- | Run a subcomputation conditional on the solution limit still +-- being positive. If the solution limit has reached zero, stop +-- early. +withSolutionLimit :: + (Member (State SolutionLimit) r, Member (Output (Message ann)) r, Monoid a) => + Sem r a -> + Sem r a +withSolutionLimit m = do + SolutionLimit lim <- get + case lim of + 0 -> do + debug "Reached solution limit, stopping early..." + return mempty + _ -> m + -------------------------------------------------- -- Simple constraints @@ -239,7 +265,7 @@ lkup messg m k = fromMaybe (error errMsg) (M.lookup k m) -- Top-level solver algorithm solveConstraint :: - Members '[Fresh, Error SolveError, Output (Message ann), Input TyDefCtx] r => + Members '[Fresh, Error SolveError, Output (Message ann), Input TyDefCtx, State SolutionLimit] r => Constraint -> Sem r (NonEmpty S) solveConstraint c = do @@ -248,6 +274,7 @@ solveConstraint c = do -- list of possible constraint sets; each one consists of equational -- and subtyping constraints in addition to qualifiers. + debug "============================================================" debug "Solving:" debugPretty c @@ -260,11 +287,13 @@ solveConstraint c = do sconcat <$> filterErrors (NE.map (uncurry solveConstraintChoice) qcList) solveConstraintChoice :: - Members '[Fresh, Error SolveError, Output (Message ann), Input TyDefCtx] r => + Members '[Fresh, Error SolveError, Output (Message ann), Input TyDefCtx, State SolutionLimit] r => TyVarInfoMap -> [SimpleConstraint] -> Sem r (NonEmpty S) solveConstraintChoice quals cs = do + debug "solveConstraintChoice" + debugPretty quals debug $ vcat (map pretty' cs) @@ -955,6 +984,9 @@ ubsBySort, lbsBySort :: TyVarInfoMap -> RelMap -> [BaseTy] -> Sort -> Set (Name ubsBySort vm rm = allBySort vm rm SuperTy lbsBySort vm rm = allBySort vm rm SubTy +maxSolutions :: Int +maxSolutions = 16 + -- | From the constraint graph, build the sets of sub- and super- base -- types of each type variable, as well as the sets of sub- and -- supertype variables. For each type variable x in turn, try to @@ -969,11 +1001,14 @@ lbsBySort vm rm = allBySort vm rm SubTy -- predecessors in this case, since it seems nice to default to -- "simpler" types lower down in the subtyping chain. solveGraph :: - Members '[Fresh, Error SolveError, Output (Message ann)] r => + Members '[Fresh, Error SolveError, Output (Message ann), State SolutionLimit] r => TyVarInfoMap -> Graph UAtom -> Sem r [S] -solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap +solveGraph vm g = do + debug "Solving graph..." + debugPretty g + map (atomToTypeSubst . unifyWCC) <$> go topRelMap where unifyWCC :: Substitution BaseTy -> Substitution Atom unifyWCC s = compose (map mkEquateSubst wccVarGroups) @@ fmap ABase s @@ -1035,13 +1070,17 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap fromVar _ = error "Impossible! UB but uisVar." go :: - Members '[Fresh, Output (Message ann)] r => + Members '[Fresh, Output (Message ann), State SolutionLimit] r => RelMap -> - Sem r (Set (Substitution BaseTy)) - go relMap@(RelMap rm) = - debugPretty relMap >> case as of + Sem r [Substitution BaseTy] + go relMap@(RelMap rm) = withSolutionLimit $ do + debugPretty relMap + case as of -- No variables left that have base type constraints. - [] -> return $ S.singleton idS + [] -> do + -- Found a solution, decrement the counter. + countSolution + return [idS] -- Solve one variable at a time. See below. (a : _) -> do debug $ "Solving for" <+> pretty' a @@ -1059,9 +1098,9 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap -- anything (indeed, some variables might not be keys if -- they have an empty sort), so it doesn't matter if old -- variables hang around in it. - ss' <- forM (S.toList ss) $ \s -> - S.map (@@ s) <$> go (substRel a (fromJust $ Subst.lookup (coerce a) s) relMap) - return (S.unions ss') + ss' <- forM ss $ \s -> + map (@@ s) <$> go (substRel a (fromJust $ Subst.lookup (coerce a) s) relMap) + return (join ss') where -- NOTE we can't solve a bunch in parallel! Might end up -- assigning them conflicting solutions if some depend on @@ -1109,7 +1148,7 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap -- them. -- Solve for a variable, returning all possible substitutions. - solveVar :: Name Type -> Set (Substitution BaseTy) + solveVar :: Name Type -> [Substitution BaseTy] solveVar v = case ((v, SuperTy), (v, SubTy)) & over both (S.toList . baseRels . lkup "solveGraph.solveVar" rm) of -- No sub- or supertypes; the only way this can happen is @@ -1140,11 +1179,11 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap ([], []) -> -- Debug.trace (show v ++ " has no sub- or supertypes.") -- Pick some base type with an appropriate sort. - S.map (coerce v |->) $ S.fromList (filter (`hasSort` getSort vm v) [N, Z, F, Q, B, C]) + map (coerce v |->) $ filter (`hasSort` getSort vm v) [N, Z, F, Q, B, C] -- Only supertypes. Just assign a to their inf, if one exists. (bsupers, []) -> -- Debug.trace (show v ++ " has only supertypes (" ++ show bsupers ++ ")") $ - S.map (coerce v |->) $ + map (coerce v |->) . S.toList $ lbsBySort vm relMap @@ -1158,7 +1197,7 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap -- Debug.trace ("relmap: " ++ show relMap) $ -- Debug.trace ("sort for " ++ show v ++ ": " ++ show (getSort vm v)) $ -- Debug.trace ("relvars: " ++ show (varRels (relMap ! (v,SubTy)))) $ - S.map (coerce v |->) $ + map (coerce v |->) . S.toList $ ubsBySort vm relMap @@ -1171,5 +1210,6 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap let mub = glbBySort vm relMap bsupers (getSort vm v) (varRels (rm ! (v, SuperTy))) mlb = lubBySort vm relMap bsubs (getSort vm v) (varRels (rm ! (v, SubTy))) in case (mlb, mub) of - (Just lb, Just ub) -> S.map (coerce v |->) (S.fromList (filter (`isSubB` ub) (supertypes lb))) - _ -> S.empty + (Just lb, Just ub) -> + map (coerce v |->) (filter (`isSubB` ub) (supertypes lb)) + _ -> [] diff --git a/src/Disco/Typecheck/Util.hs b/src/Disco/Typecheck/Util.hs index cdcb531e..2e4fd57d 100644 --- a/src/Disco/Typecheck/Util.hs +++ b/src/Disco/Typecheck/Util.hs @@ -142,16 +142,17 @@ withConstraint :: Sem (Writer Constraint ': r) a -> Sem r (a, Constraint) withConstraint = fmap swap . runWriter -- | Run a computation and solve its generated constraint, returning --- all the possible resulting substitutions (or failing with an --- error). Note that this locally dispatches the constraint writer --- effect. +-- up to the requested number of possible resulting substitutions +-- (or failing with an error). Note that this locally dispatches +-- the constraint writer and solution limit effects. solve :: Members '[Reader TyDefCtx, Error TCError, Output (Message ann)] r => + Int -> Sem (Writer Constraint ': r) a -> Sem r (a, NonEmpty S) -solve m = do +solve lim m = do (a, c) <- withConstraint m - res <- runSolve . inputToReader . solveConstraint $ c + res <- runSolve (SolutionLimit lim) . inputToReader . solveConstraint $ c case res of Left e -> throw (Unsolvable e) Right ss -> return (a, ss)