From 03e051abb08e30dfb96b0857d6715917200acf5d Mon Sep 17 00:00:00 2001 From: LeitMoth Date: Sat, 20 Jul 2024 13:52:56 -0500 Subject: [PATCH] User defined types now work! --- src/Disco/Eval.hs | 8 ++-- src/Disco/Exhaustiveness.hs | 25 ++++++----- src/Disco/Exhaustiveness/Constraint.hs | 24 ++++++----- src/Disco/Exhaustiveness/TypeInfo.hs | 59 ++++++++++++++------------ 4 files changed, 66 insertions(+), 50 deletions(-) diff --git a/src/Disco/Eval.hs b/src/Disco/Eval.hs index a40bd5f4..a66aee96 100644 --- a/src/Disco/Eval.hs +++ b/src/Disco/Eval.hs @@ -387,7 +387,7 @@ loadParsedDiscoModule' quiet mode resolver inProcess name cm@(Module _ mns _ _ _ m <- runTCM tyctx tydefns $ checkModule name importMap cm -- Check for partial functions - runFresh $ mapM_ checkExhaustive $ (Ctx.elems $ m^.miTermdefs) + runFresh $ mapM_ (checkExhaustive tydefns) (Ctx.elems $ m^.miTermdefs) -- Evaluate all the module definitions and add them to the topEnv. mapError EvalErr $ loadDefsFrom m @@ -449,7 +449,7 @@ loadDef x body = do v <- inputToState @TopInfo . inputTopEnv $ eval body modify @TopInfo $ topEnv %~ Ctx.insert x v -checkExhaustive :: Members '[Fresh, Embed IO] r => Defn -> Sem r () -checkExhaustive (Defn name argsType _ boundClauses) = do +checkExhaustive :: Members '[Fresh, Embed IO] r => TyDefCtx -> Defn -> Sem r () +checkExhaustive tyDefCtx (Defn name argsType _ boundClauses) = do clauses <- NonEmpty.map fst <$> mapM unbind boundClauses - checkClauses name argsType clauses + runReader @TyDefCtx tyDefCtx $ checkClauses name argsType clauses diff --git a/src/Disco/Exhaustiveness.hs b/src/Disco/Exhaustiveness.hs index 0a7ff5f9..f07ee642 100644 --- a/src/Disco/Exhaustiveness.hs +++ b/src/Disco/Exhaustiveness.hs @@ -32,8 +32,9 @@ import qualified Disco.Types as Ty import Polysemy import Text.Show.Pretty (pPrint) import Unbound.Generics.LocallyNameless (Name) +import Polysemy.Reader -checkClauses :: (Members '[Fresh, Embed IO] r) => Name ATerm -> [Ty.Type] -> NonEmpty [APattern] -> Sem r () +checkClauses :: Members '[Fresh, Reader Ty.TyDefCtx, Embed IO] r => Name ATerm -> [Ty.Type] -> NonEmpty [APattern] -> Sem r () checkClauses name types pats = do args <- TI.newVars types cl <- zipWithM (desugarClause args) [1 ..] (NonEmpty.toList pats) @@ -179,7 +180,7 @@ data Ant where ABranch :: Ant -> Ant -> Ant deriving (Show) -ua :: (Member Fresh r) => [C.NormRefType] -> Gdt -> Sem r ([C.NormRefType], Ant) +ua :: Members '[Fresh, Reader Ty.TyDefCtx] r => [C.NormRefType] -> Gdt -> Sem r ([C.NormRefType], Ant) ua nrefs gdt = case gdt of Grhs k -> return ([], AGrhs nrefs k) Branch t1 t2 -> do @@ -195,7 +196,7 @@ ua nrefs gdt = case gdt of n'' <- addLitMulti nrefs $ Literal (x, LitNot dc) return (n'' ++ n', u) -addLitMulti :: (Members '[Fresh] r) => [C.NormRefType] -> Literal -> Sem r [C.NormRefType] +addLitMulti :: Members '[Fresh, Reader Ty.TyDefCtx] r => [C.NormRefType] -> Literal -> Sem r [C.NormRefType] addLitMulti [] _ = return [] addLitMulti (n : ns) lit = do r <- runMaybeT $ addLiteral n lit @@ -205,7 +206,7 @@ addLitMulti (n : ns) lit = do ns' <- addLitMulti ns lit return $ (ctx, cfs) : ns' -addLiteral :: (Members '[Fresh] r) => C.NormRefType -> Literal -> MaybeT (Sem r) C.NormRefType +addLiteral :: Members '[Fresh, Reader Ty.TyDefCtx] r => C.NormRefType -> Literal -> MaybeT (Sem r) C.NormRefType addLiteral (context, constraints) (Literal (x, c)) = case c of LitWasOriginally z -> (context `C.addVars` [x], constraints) `C.addConstraint` (x, C.CWasOriginally z) @@ -234,6 +235,9 @@ joinSpace = foldr1 (join " ") -- as strings are lists of chars. -- Maybe for strings, we just list the top 3 uncovered patterns -- consiting of only postive information, sorted by length? +-- Also, how should we print out sum types? +-- We can do right(thing) or (right thing), or (right(thing)) +-- Which should we chose? prettyInhab :: InhabPat -> String prettyInhab (IPNot []) = "_" prettyInhab (IPNot nots) = "Not{" ++ joinComma (map dcToString nots) ++ "}" @@ -272,17 +276,17 @@ mkIPMatch k pats = then error $ "Wrong number of DataCon args" ++ show (k, pats) else IPIs k pats -findInhabitants :: (Members '[Fresh] r) => [C.NormRefType] -> [TI.TypedVar] -> Sem r (Poss.Possibilities [InhabPat]) +findInhabitants :: Members '[Fresh, Reader Ty.TyDefCtx] r => [C.NormRefType] -> [TI.TypedVar] -> Sem r (Poss.Possibilities [InhabPat]) findInhabitants nrefs args = do a <- forM nrefs (`findAllForNref` args) return $ Poss.anyOf a -findAllForNref :: (Member Fresh r) => C.NormRefType -> [TI.TypedVar] -> Sem r (Poss.Possibilities [InhabPat]) +findAllForNref :: Members '[Fresh, Reader Ty.TyDefCtx] r => C.NormRefType -> [TI.TypedVar] -> Sem r (Poss.Possibilities [InhabPat]) findAllForNref nref args = do argPats <- forM args (`findVarInhabitants` nref) return $ Poss.allCombinations argPats -findVarInhabitants :: (Members '[Fresh] r) => TI.TypedVar -> C.NormRefType -> Sem r (Poss.Possibilities InhabPat) +findVarInhabitants :: Members '[Fresh, Reader Ty.TyDefCtx] r => TI.TypedVar -> C.NormRefType -> Sem r (Poss.Possibilities InhabPat) findVarInhabitants var nref@(_, cns) = case posMatch of Just (k, args) -> do @@ -290,8 +294,9 @@ findVarInhabitants var nref@(_, cns) = return (mkIPMatch k <$> argPossibilities) Nothing -> case nub negMatches of [] -> Poss.retSingle $ IPNot [] - neg -> - case TI.tyDataCons . TI.getType $ var of + neg -> do + tyCtx <- ask @Ty.TyDefCtx + case TI.tyDataCons (TI.getType var) tyCtx of Nothing -> Poss.retSingle $ IPNot neg Just dcs -> do @@ -313,7 +318,7 @@ findVarInhabitants var nref@(_, cns) = posMatch = C.posMatch constraintsOnX negMatches = C.negMatches constraintsOnX -findRedundant :: (Member Fresh r) => Ant -> [TI.TypedVar] -> Sem r [Int] +findRedundant :: Members '[Fresh, Reader Ty.TyDefCtx] r => Ant -> [TI.TypedVar] -> Sem r [Int] findRedundant ant args = case ant of AGrhs ref i -> do uninhabited <- Poss.none <$> findInhabitants ref args diff --git a/src/Disco/Exhaustiveness/Constraint.hs b/src/Disco/Exhaustiveness/Constraint.hs index 646c97d6..a1ced538 100644 --- a/src/Disco/Exhaustiveness/Constraint.hs +++ b/src/Disco/Exhaustiveness/Constraint.hs @@ -1,7 +1,7 @@ module Disco.Exhaustiveness.Constraint where import Control.Applicative (Alternative) -import Control.Monad (foldM, guard, replicateM) +import Control.Monad (foldM, guard) import Control.Monad.Trans (lift) import Control.Monad.Trans.Maybe (MaybeT, runMaybeT) import Data.List (partition) @@ -9,6 +9,8 @@ import Data.Maybe (isJust, listToMaybe, mapMaybe) import Disco.Effects.Fresh (Fresh) import qualified Disco.Exhaustiveness.TypeInfo as TI import Polysemy +import qualified Disco.Types as Ty +import Polysemy.Reader newtype Context = Context {getCtxVars :: [TI.TypedVar]} deriving (Show, Eq) @@ -47,15 +49,15 @@ onVar x cs = alistLookup (lookupVar x cs) cs type NormRefType = (Context, [ConstraintFor]) -addConstraints :: (Members '[Fresh] r) => NormRefType -> [ConstraintFor] -> MaybeT (Sem r) NormRefType +addConstraints :: Members '[Fresh, Reader Ty.TyDefCtx] r => NormRefType -> [ConstraintFor] -> MaybeT (Sem r) NormRefType addConstraints = foldM addConstraint -addConstraint :: (Members '[Fresh] r) => NormRefType -> ConstraintFor -> MaybeT (Sem r) NormRefType +addConstraint :: Members '[Fresh, Reader Ty.TyDefCtx] r => NormRefType -> ConstraintFor -> MaybeT (Sem r) NormRefType addConstraint nref@(_, cns) (x, c) = do breakIf $ any (conflictsWith c) (onVar x cns) addConstraintHelper nref (lookupVar x cns, c) -addConstraintHelper :: (Members '[Fresh] r) => NormRefType -> ConstraintFor -> MaybeT (Sem r) NormRefType +addConstraintHelper :: Members '[Fresh, Reader Ty.TyDefCtx] r => NormRefType -> ConstraintFor -> MaybeT (Sem r) NormRefType addConstraintHelper nref@(ctx, cns) cf@(origX, c) = case c of --- Equation (10) CMatch k args -> do @@ -126,17 +128,19 @@ substituteVarIDs y x = map (\(var, c) -> (subst var, c)) -- This function tests if this is true -- NOTE(colin): we may eventually have type constraints -- and we would need to worry pulling them from nref here -inhabited :: (Members '[Fresh] r) => NormRefType -> TI.TypedVar -> Sem r Bool -inhabited n var = case TI.tyDataCons . TI.getType $ var of - Nothing -> return True -- assume opaque types are inhabited - Just constructors -> do - or <$> mapM (instantiate n var) constructors +inhabited :: Members '[Fresh, Reader Ty.TyDefCtx] r => NormRefType -> TI.TypedVar -> Sem r Bool +inhabited n var = do + tyCtx <- ask @Ty.TyDefCtx + case TI.tyDataCons (TI.getType var) tyCtx of + Nothing -> return True -- assume opaque types are inhabited + Just constructors -> do + or <$> mapM (instantiate n var) constructors -- Attempts to "instantiate" a match of the dataconstructor k on x -- If we can add the MatchDataCon constraint to the normalized refinement -- type without contradiction (a Nothing value), -- then x is inhabited by k and we return true -instantiate :: (Members '[Fresh] r) => NormRefType -> TI.TypedVar -> TI.DataCon -> Sem r Bool +instantiate :: Members '[Fresh, Reader Ty.TyDefCtx] r => NormRefType -> TI.TypedVar -> TI.DataCon -> Sem r Bool instantiate (ctx, cns) var k = do args <- TI.newVars $ TI.dcTypes k let attempt = (ctx `addVars` args, cns) `addConstraint` (var, CMatch k args) diff --git a/src/Disco/Exhaustiveness/TypeInfo.hs b/src/Disco/Exhaustiveness/TypeInfo.hs index 462158a2..1fd098dc 100644 --- a/src/Disco/Exhaustiveness/TypeInfo.hs +++ b/src/Disco/Exhaustiveness/TypeInfo.hs @@ -1,6 +1,7 @@ module Disco.Exhaustiveness.TypeInfo where import Control.Monad (replicateM) +import qualified Data.Map as M import Disco.AST.Typed (ATerm) import Disco.Effects.Fresh (Fresh, fresh) import qualified Disco.Types as Ty @@ -63,32 +64,38 @@ left tl = DataCon {dcIdent = KLeft, dcTypes = [tl]} right :: Ty.Type -> DataCon right tr = DataCon {dcIdent = KRight, dcTypes = [tr]} -{- -TODO(colin): Fill out the remaining types here -Remaining: - TyVar - , TySkolem - , TyProp - , TyBag - , TySet - , TyGraph - , TyMap - , TyUser -Impossible: - , (:->:) --} -tyDataCons :: Ty.Type -> Maybe [DataCon] -tyDataCons (a Ty.:*: b) = Just [pair a b] -tyDataCons (l Ty.:+: r) = Just [left l, right r] -tyDataCons t@(Ty.TyList a) = Just [cons a t, nil] -tyDataCons Ty.TyVoid = Just [] -tyDataCons Ty.TyUnit = Just [unit] -tyDataCons Ty.TyBool = Just [bool True, bool False] -tyDataCons Ty.TyN = Nothing -tyDataCons Ty.TyZ = Nothing -tyDataCons Ty.TyF = Nothing -tyDataCons Ty.TyQ = Nothing -tyDataCons Ty.TyC = Nothing +-- TODO(colin): ask yorgey, make sure I've done this correctly +-- If I have, and this is enough, I can remove all mentions +-- of type equality constraints in Constraint.hs, +-- the lookup here will have handled that behavoir already +tyDataCons :: Ty.Type -> Ty.TyDefCtx -> Maybe [DataCon] +tyDataCons (Ty.TyUser name args) ctx = case M.lookup name ctx of + Nothing -> error $ "Type definition not found for: " ++ show name + Just (Ty.TyDefBody _argNames typeCon) -> tyDataCons (typeCon args) ctx +tyDataCons (a Ty.:*: b) _ = Just [pair a b] +tyDataCons (l Ty.:+: r) _ = Just [left l, right r] +tyDataCons t@(Ty.TyList a) _ = Just [cons a t, nil] +tyDataCons Ty.TyVoid _ = Just [] +tyDataCons Ty.TyUnit _ = Just [unit] +tyDataCons Ty.TyBool _ = Just [bool True, bool False] +tyDataCons Ty.TyN _ = Nothing +tyDataCons Ty.TyZ _ = Nothing +tyDataCons Ty.TyF _ = Nothing +tyDataCons Ty.TyQ _ = Nothing +tyDataCons Ty.TyC _ = Nothing +tyDataCons (_ Ty.:->: _) _ = error "Functions not allowed in patterns." +tyDataCons (Ty.TySet _) _ = error "Sets not allowed in patterns." +tyDataCons (Ty.TyBag _) _ = error "Bags not allowed in patterns." +-- I'm unsure about these two. +-- They may come up when doing generic stuff, +-- I haven't ecnountered them so far +-- TODO(colin): ask Yorgey about this +tyDataCons (Ty.TyVar _) _ = error "Encountered type var in pattern" +tyDataCons (Ty.TySkolem _) _ = error "Encountered skolem in pattern" +-- Unsure about these as well +tyDataCons (Ty.TyProp) _ = error "Propositions not allowed in patterns." +tyDataCons (Ty.TyMap _ _) _ = error "Maps not allowed in patterns." +tyDataCons (Ty.TyGraph _) _ = error "Graph not allowed in patterns." newName :: (Member Fresh r) => Sem r (Name ATerm) newName = fresh $ s2n ""