Skip to content

Commit

Permalink
limit the number of intermediate solutions generated in the constrain…
Browse files Browse the repository at this point in the history
…t solver to avoid exponential blowup
  • Loading branch information
byorgey committed May 25, 2024
1 parent 0b8fe87 commit baaf47e
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 40 deletions.
5 changes: 4 additions & 1 deletion src/Disco/Interactive/Commands.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1029,12 +1029,15 @@ typeCheckCmd =
, parser = TypeCheck <$> parseTermOrOp
}

maxInferredTypes :: Int
maxInferredTypes = 16

handleTypeCheck ::
Members '[Error DiscoError, Input TopInfo, LFresh, Output (Message ())] r =>
REPLExpr 'CTypeCheck ->
Sem r ()
handleTypeCheck (TypeCheck t) = do
asigs <- typecheckTop $ inferTop t
asigs <- typecheckTop $ inferTop maxInferredTypes t
sigs <- runFresh . mapInput (view (replModInfo . miTydefs)) $ thin $ NE.map snd asigs
let (toShow, extra) = NE.splitAt 8 sigs
when (length sigs > 1) $ info "This expression has multiple possible types. Some examples:"
Expand Down
46 changes: 33 additions & 13 deletions src/Disco/Typecheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import qualified Disco.Subst as Subst
import Disco.Syntax.Operators
import Disco.Syntax.Prims
import Disco.Typecheck.Constraints
import Disco.Typecheck.Solve (solveConstraint)
import Disco.Typecheck.Solve (SolutionLimit (..), solveConstraint)
import Disco.Typecheck.Util
import Disco.Types
import Disco.Types.Rules
Expand All @@ -50,6 +50,7 @@ import Polysemy.Error
import Polysemy.Input
import Polysemy.Output
import Polysemy.Reader
import Polysemy.State (evalState)
import Polysemy.Writer
import Text.EditDistance (defaultEditCosts, restrictedDamerauLevenshteinDistance)
import Unbound.Generics.LocallyNameless (
Expand Down Expand Up @@ -281,6 +282,8 @@ checkDefn ::
TermDefn ->
Sem r Defn
checkDefn name (TermDefn x clauses) = mapError (LocTCError (Just (name .- x))) $ do
debug "======================================================================"
debug "Checking definition:"
-- Check that all clauses have the same number of patterns
checkNumPats clauses

Expand All @@ -295,7 +298,7 @@ checkDefn name (TermDefn x clauses) = mapError (LocTCError (Just (name .- x))) $
-- patterns, and lazily unrolling type definitions along the way.
(patTys, bodyTy) <- decomposeDefnTy (numPats (NE.head clauses)) ty

((acs, _), thetas) <- solve $ do
((acs, _), thetas) <- solve 1 $ do
aclauses <- forAll nms $ mapM (checkClause patTys bodyTy) clauses
return (aclauses, ty)

Expand All @@ -311,7 +314,7 @@ checkDefn name (TermDefn x clauses) = mapError (LocTCError (Just (name .- x))) $
-- patterns don't match across different clauses
| otherwise = return ()

-- \| Check a clause of a definition against a list of pattern types and a body type.
-- Check a clause of a definition against a list of pattern types and a body type.
checkClause ::
Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r =>
[Type] ->
Expand Down Expand Up @@ -359,7 +362,10 @@ checkProperty ::
Property ->
Sem r AProperty
checkProperty prop = do
(at, thetas) <- solve $ check prop TyProp
debug "======================================================================"
debug "Checking property:"
debugPretty prop
(at, thetas) <- solve 1 $ check prop TyProp
-- XXX do we need to default container variables here?
return $ applySubst (NE.head thetas) at

Expand Down Expand Up @@ -454,24 +460,30 @@ inferTop' ::
Members '[Output (Message ann), Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh] r =>
Term ->
Sem r (ATerm, PolyType)
inferTop' = fmap NE.head . inferTop
inferTop' t = NE.head <$> inferTop 1 t

-- | Top-level type inference algorithm: infer some possible
-- (polymorphic) types for a term by running type inference, solving
-- the resulting constraints, and quantifying over any remaining
-- type variables.
-- | Top-level type inference algorithm: infer up to the requested max
-- number of possible (polymorphic) types for a term by running type
-- inference, solving the resulting constraints, and quantifying
-- over any remaining type variables.
inferTop ::
Members '[Output (Message ann), Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh] r =>
Int ->
Term ->
Sem r (NonEmpty (ATerm, PolyType))
inferTop t = do
inferTop lim t = do
-- Run inference on the term and try to solve the resulting
-- constraints.
(at, thetas) <- solve $ infer t
debug "======================================================================"
debug "Inferring the type of:"
debugPretty t
(at, thetas) <- solve lim $ infer t

debug "Final annotated term (before substitution and container monomorphizing):"
debugPretty at

-- XXX include this container variable stuff in solution limits...?

-- Quantify over any remaining type variables and return
-- the term along with the resulting polymorphic type.
return $ do
Expand Down Expand Up @@ -504,7 +516,11 @@ checkTop ::
PolyType ->
Sem r ATerm
checkTop t ty = do
(at, theta) <- solve $ checkPolyTy t ty
debug "======================================================================"
debug "Checking the type of:"
debugPretty t
debugPretty ty
(at, theta) <- solve 1 $ checkPolyTy t ty
return $ applySubst (NE.head theta) at

--------------------------------------------------
Expand Down Expand Up @@ -1746,7 +1762,11 @@ isSubPolyType (Forall b1) (Forall b2) = do
(as1, ty1) <- unbind b1
(as2, ty2) <- unbind b2
let c = CAll (bind as1 (CSub ty1 (substs (zip as2 (map TyVar as1)) ty2)))
ss <- runError (solveConstraint c)
debug "======================================================================"
debug "Checking subtyping..."
debugPretty (Forall b1)
debugPretty (Forall b2)
ss <- runError (evalState (SolutionLimit 1) (solveConstraint c))
return (either (const False) (not . P.null) ss)

thin :: Members '[Input TyDefCtx, Output (Message ann), Fresh] r => NonEmpty PolyType -> Sem r (NonEmpty PolyType)
Expand Down
82 changes: 61 additions & 21 deletions src/Disco/Typecheck/Solve.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import Unbound.Generics.LocallyNameless (

import Control.Arrow ((&&&), (***))
import Control.Lens hiding (use, (%=), (.=))
import Control.Monad (forM, unless, zipWithM)
import Control.Monad (forM, join, unless, zipWithM)
import Data.Bifunctor (first, second)
import Data.Coerce
import Data.Either (partitionEithers)
Expand Down Expand Up @@ -89,8 +89,8 @@ instance Semigroup SolveError where
--------------------------------------------------
-- Error utilities

runSolve :: Sem (Fresh ': Error SolveError ': r) a -> Sem r (Either SolveError a)
runSolve = runError . runFresh
runSolve :: SolutionLimit -> Sem (State SolutionLimit ': Fresh ': Error SolveError ': r) a -> Sem r (Either SolveError a)
runSolve lim = runError . runFresh . evalState lim

-- | Run a list of actions, and return the results from those which do
-- not throw an error. If all of them throw an error, rethrow the
Expand All @@ -102,6 +102,32 @@ filterErrors ms = do
Left (e :| _) -> throw e
Right (_, as) -> return as

--------------------------------------------------
-- Solution limits

-- | Max number of solutions to generate.
newtype SolutionLimit = SolutionLimit {getSolutionLimit :: Int}

-- | Register the fact that we found one solution, by decrementing the
-- solution limit.
countSolution :: Member (State SolutionLimit) r => Sem r ()
countSolution = modify (SolutionLimit . subtract 1 . getSolutionLimit)

-- | Run a subcomputation conditional on the solution limit still
-- being positive. If the solution limit has reached zero, stop
-- early.
withSolutionLimit ::
(Member (State SolutionLimit) r, Member (Output (Message ann)) r, Monoid a) =>
Sem r a ->
Sem r a
withSolutionLimit m = do
SolutionLimit lim <- get
case lim of
0 -> do
debug "Reached solution limit, stopping early..."
return mempty
_ -> m

--------------------------------------------------
-- Simple constraints

Expand Down Expand Up @@ -239,7 +265,7 @@ lkup messg m k = fromMaybe (error errMsg) (M.lookup k m)
-- Top-level solver algorithm

solveConstraint ::
Members '[Fresh, Error SolveError, Output (Message ann), Input TyDefCtx] r =>
Members '[Fresh, Error SolveError, Output (Message ann), Input TyDefCtx, State SolutionLimit] r =>
Constraint ->
Sem r (NonEmpty S)
solveConstraint c = do
Expand All @@ -248,6 +274,7 @@ solveConstraint c = do
-- list of possible constraint sets; each one consists of equational
-- and subtyping constraints in addition to qualifiers.

debug "============================================================"
debug "Solving:"
debugPretty c

Expand All @@ -260,11 +287,13 @@ solveConstraint c = do
sconcat <$> filterErrors (NE.map (uncurry solveConstraintChoice) qcList)

solveConstraintChoice ::
Members '[Fresh, Error SolveError, Output (Message ann), Input TyDefCtx] r =>
Members '[Fresh, Error SolveError, Output (Message ann), Input TyDefCtx, State SolutionLimit] r =>
TyVarInfoMap ->
[SimpleConstraint] ->
Sem r (NonEmpty S)
solveConstraintChoice quals cs = do
debug "solveConstraintChoice"

debugPretty quals
debug $ vcat (map pretty' cs)

Expand Down Expand Up @@ -955,6 +984,9 @@ ubsBySort, lbsBySort :: TyVarInfoMap -> RelMap -> [BaseTy] -> Sort -> Set (Name
ubsBySort vm rm = allBySort vm rm SuperTy
lbsBySort vm rm = allBySort vm rm SubTy

maxSolutions :: Int
maxSolutions = 16

-- | From the constraint graph, build the sets of sub- and super- base
-- types of each type variable, as well as the sets of sub- and
-- supertype variables. For each type variable x in turn, try to
Expand All @@ -969,11 +1001,14 @@ lbsBySort vm rm = allBySort vm rm SubTy
-- predecessors in this case, since it seems nice to default to
-- "simpler" types lower down in the subtyping chain.
solveGraph ::
Members '[Fresh, Error SolveError, Output (Message ann)] r =>
Members '[Fresh, Error SolveError, Output (Message ann), State SolutionLimit] r =>
TyVarInfoMap ->
Graph UAtom ->
Sem r [S]
solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap
solveGraph vm g = do
debug "Solving graph..."
debugPretty g
map (atomToTypeSubst . unifyWCC) <$> go topRelMap
where
unifyWCC :: Substitution BaseTy -> Substitution Atom
unifyWCC s = compose (map mkEquateSubst wccVarGroups) @@ fmap ABase s
Expand Down Expand Up @@ -1035,13 +1070,17 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap
fromVar _ = error "Impossible! UB but uisVar."

go ::
Members '[Fresh, Output (Message ann)] r =>
Members '[Fresh, Output (Message ann), State SolutionLimit] r =>
RelMap ->
Sem r (Set (Substitution BaseTy))
go relMap@(RelMap rm) =
debugPretty relMap >> case as of
Sem r [Substitution BaseTy]
go relMap@(RelMap rm) = withSolutionLimit $ do
debugPretty relMap
case as of
-- No variables left that have base type constraints.
[] -> return $ S.singleton idS
[] -> do
-- Found a solution, decrement the counter.
countSolution
return [idS]
-- Solve one variable at a time. See below.
(a : _) -> do
debug $ "Solving for" <+> pretty' a
Expand All @@ -1059,9 +1098,9 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap
-- anything (indeed, some variables might not be keys if
-- they have an empty sort), so it doesn't matter if old
-- variables hang around in it.
ss' <- forM (S.toList ss) $ \s ->
S.map (@@ s) <$> go (substRel a (fromJust $ Subst.lookup (coerce a) s) relMap)
return (S.unions ss')
ss' <- forM ss $ \s ->
map (@@ s) <$> go (substRel a (fromJust $ Subst.lookup (coerce a) s) relMap)
return (join ss')
where
-- NOTE we can't solve a bunch in parallel! Might end up
-- assigning them conflicting solutions if some depend on
Expand Down Expand Up @@ -1109,7 +1148,7 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap
-- them.

-- Solve for a variable, returning all possible substitutions.
solveVar :: Name Type -> Set (Substitution BaseTy)
solveVar :: Name Type -> [Substitution BaseTy]
solveVar v =
case ((v, SuperTy), (v, SubTy)) & over both (S.toList . baseRels . lkup "solveGraph.solveVar" rm) of
-- No sub- or supertypes; the only way this can happen is
Expand Down Expand Up @@ -1140,11 +1179,11 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap
([], []) ->
-- Debug.trace (show v ++ " has no sub- or supertypes.")
-- Pick some base type with an appropriate sort.
S.map (coerce v |->) $ S.fromList (filter (`hasSort` getSort vm v) [N, Z, F, Q, B, C])
map (coerce v |->) $ filter (`hasSort` getSort vm v) [N, Z, F, Q, B, C]
-- Only supertypes. Just assign a to their inf, if one exists.
(bsupers, []) ->
-- Debug.trace (show v ++ " has only supertypes (" ++ show bsupers ++ ")") $
S.map (coerce v |->) $
map (coerce v |->) . S.toList $
lbsBySort
vm
relMap
Expand All @@ -1158,7 +1197,7 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap
-- Debug.trace ("relmap: " ++ show relMap) $
-- Debug.trace ("sort for " ++ show v ++ ": " ++ show (getSort vm v)) $
-- Debug.trace ("relvars: " ++ show (varRels (relMap ! (v,SubTy)))) $
S.map (coerce v |->) $
map (coerce v |->) . S.toList $
ubsBySort
vm
relMap
Expand All @@ -1171,5 +1210,6 @@ solveGraph vm g = map (atomToTypeSubst . unifyWCC) . S.toList <$> go topRelMap
let mub = glbBySort vm relMap bsupers (getSort vm v) (varRels (rm ! (v, SuperTy)))
mlb = lubBySort vm relMap bsubs (getSort vm v) (varRels (rm ! (v, SubTy)))
in case (mlb, mub) of
(Just lb, Just ub) -> S.map (coerce v |->) (S.fromList (filter (`isSubB` ub) (supertypes lb)))
_ -> S.empty
(Just lb, Just ub) ->
map (coerce v |->) (filter (`isSubB` ub) (supertypes lb))
_ -> []
11 changes: 6 additions & 5 deletions src/Disco/Typecheck/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,17 @@ withConstraint :: Sem (Writer Constraint ': r) a -> Sem r (a, Constraint)
withConstraint = fmap swap . runWriter

-- | Run a computation and solve its generated constraint, returning
-- all the possible resulting substitutions (or failing with an
-- error). Note that this locally dispatches the constraint writer
-- effect.
-- up to the requested number of possible resulting substitutions
-- (or failing with an error). Note that this locally dispatches
-- the constraint writer and solution limit effects.
solve ::
Members '[Reader TyDefCtx, Error TCError, Output (Message ann)] r =>
Int ->
Sem (Writer Constraint ': r) a ->
Sem r (a, NonEmpty S)
solve m = do
solve lim m = do
(a, c) <- withConstraint m
res <- runSolve . inputToReader . solveConstraint $ c
res <- runSolve (SolutionLimit lim) . inputToReader . solveConstraint $ c
case res of
Left e -> throw (Unsolvable e)
Right ss -> return (a, ss)
Expand Down

0 comments on commit baaf47e

Please sign in to comment.