From 03e051abb08e30dfb96b0857d6715917200acf5d Mon Sep 17 00:00:00 2001
From: LeitMoth
Date: Sat, 20 Jul 2024 13:52:56 -0500
Subject: [PATCH] User defined types now work!
---
src/Disco/Eval.hs | 8 ++--
src/Disco/Exhaustiveness.hs | 25 ++++++-----
src/Disco/Exhaustiveness/Constraint.hs | 24 ++++++-----
src/Disco/Exhaustiveness/TypeInfo.hs | 59 ++++++++++++++------------
4 files changed, 66 insertions(+), 50 deletions(-)
diff --git a/src/Disco/Eval.hs b/src/Disco/Eval.hs
index a40bd5f4..a66aee96 100644
--- a/src/Disco/Eval.hs
+++ b/src/Disco/Eval.hs
@@ -387,7 +387,7 @@ loadParsedDiscoModule' quiet mode resolver inProcess name cm@(Module _ mns _ _ _
m <- runTCM tyctx tydefns $ checkModule name importMap cm
-- Check for partial functions
- runFresh $ mapM_ checkExhaustive $ (Ctx.elems $ m^.miTermdefs)
+ runFresh $ mapM_ (checkExhaustive tydefns) (Ctx.elems $ m^.miTermdefs)
-- Evaluate all the module definitions and add them to the topEnv.
mapError EvalErr $ loadDefsFrom m
@@ -449,7 +449,7 @@ loadDef x body = do
v <- inputToState @TopInfo . inputTopEnv $ eval body
modify @TopInfo $ topEnv %~ Ctx.insert x v
-checkExhaustive :: Members '[Fresh, Embed IO] r => Defn -> Sem r ()
-checkExhaustive (Defn name argsType _ boundClauses) = do
+checkExhaustive :: Members '[Fresh, Embed IO] r => TyDefCtx -> Defn -> Sem r ()
+checkExhaustive tyDefCtx (Defn name argsType _ boundClauses) = do
clauses <- NonEmpty.map fst <$> mapM unbind boundClauses
- checkClauses name argsType clauses
+ runReader @TyDefCtx tyDefCtx $ checkClauses name argsType clauses
diff --git a/src/Disco/Exhaustiveness.hs b/src/Disco/Exhaustiveness.hs
index 0a7ff5f9..f07ee642 100644
--- a/src/Disco/Exhaustiveness.hs
+++ b/src/Disco/Exhaustiveness.hs
@@ -32,8 +32,9 @@ import qualified Disco.Types as Ty
import Polysemy
import Text.Show.Pretty (pPrint)
import Unbound.Generics.LocallyNameless (Name)
+import Polysemy.Reader
-checkClauses :: (Members '[Fresh, Embed IO] r) => Name ATerm -> [Ty.Type] -> NonEmpty [APattern] -> Sem r ()
+checkClauses :: Members '[Fresh, Reader Ty.TyDefCtx, Embed IO] r => Name ATerm -> [Ty.Type] -> NonEmpty [APattern] -> Sem r ()
checkClauses name types pats = do
args <- TI.newVars types
cl <- zipWithM (desugarClause args) [1 ..] (NonEmpty.toList pats)
@@ -179,7 +180,7 @@ data Ant where
ABranch :: Ant -> Ant -> Ant
deriving (Show)
-ua :: (Member Fresh r) => [C.NormRefType] -> Gdt -> Sem r ([C.NormRefType], Ant)
+ua :: Members '[Fresh, Reader Ty.TyDefCtx] r => [C.NormRefType] -> Gdt -> Sem r ([C.NormRefType], Ant)
ua nrefs gdt = case gdt of
Grhs k -> return ([], AGrhs nrefs k)
Branch t1 t2 -> do
@@ -195,7 +196,7 @@ ua nrefs gdt = case gdt of
n'' <- addLitMulti nrefs $ Literal (x, LitNot dc)
return (n'' ++ n', u)
-addLitMulti :: (Members '[Fresh] r) => [C.NormRefType] -> Literal -> Sem r [C.NormRefType]
+addLitMulti :: Members '[Fresh, Reader Ty.TyDefCtx] r => [C.NormRefType] -> Literal -> Sem r [C.NormRefType]
addLitMulti [] _ = return []
addLitMulti (n : ns) lit = do
r <- runMaybeT $ addLiteral n lit
@@ -205,7 +206,7 @@ addLitMulti (n : ns) lit = do
ns' <- addLitMulti ns lit
return $ (ctx, cfs) : ns'
-addLiteral :: (Members '[Fresh] r) => C.NormRefType -> Literal -> MaybeT (Sem r) C.NormRefType
+addLiteral :: Members '[Fresh, Reader Ty.TyDefCtx] r => C.NormRefType -> Literal -> MaybeT (Sem r) C.NormRefType
addLiteral (context, constraints) (Literal (x, c)) = case c of
LitWasOriginally z ->
(context `C.addVars` [x], constraints) `C.addConstraint` (x, C.CWasOriginally z)
@@ -234,6 +235,9 @@ joinSpace = foldr1 (join " ")
-- as strings are lists of chars.
-- Maybe for strings, we just list the top 3 uncovered patterns
-- consiting of only postive information, sorted by length?
+-- Also, how should we print out sum types?
+-- We can do right(thing) or (right thing), or (right(thing))
+-- Which should we chose?
prettyInhab :: InhabPat -> String
prettyInhab (IPNot []) = "_"
prettyInhab (IPNot nots) = "Not{" ++ joinComma (map dcToString nots) ++ "}"
@@ -272,17 +276,17 @@ mkIPMatch k pats =
then error $ "Wrong number of DataCon args" ++ show (k, pats)
else IPIs k pats
-findInhabitants :: (Members '[Fresh] r) => [C.NormRefType] -> [TI.TypedVar] -> Sem r (Poss.Possibilities [InhabPat])
+findInhabitants :: Members '[Fresh, Reader Ty.TyDefCtx] r => [C.NormRefType] -> [TI.TypedVar] -> Sem r (Poss.Possibilities [InhabPat])
findInhabitants nrefs args = do
a <- forM nrefs (`findAllForNref` args)
return $ Poss.anyOf a
-findAllForNref :: (Member Fresh r) => C.NormRefType -> [TI.TypedVar] -> Sem r (Poss.Possibilities [InhabPat])
+findAllForNref :: Members '[Fresh, Reader Ty.TyDefCtx] r => C.NormRefType -> [TI.TypedVar] -> Sem r (Poss.Possibilities [InhabPat])
findAllForNref nref args = do
argPats <- forM args (`findVarInhabitants` nref)
return $ Poss.allCombinations argPats
-findVarInhabitants :: (Members '[Fresh] r) => TI.TypedVar -> C.NormRefType -> Sem r (Poss.Possibilities InhabPat)
+findVarInhabitants :: Members '[Fresh, Reader Ty.TyDefCtx] r => TI.TypedVar -> C.NormRefType -> Sem r (Poss.Possibilities InhabPat)
findVarInhabitants var nref@(_, cns) =
case posMatch of
Just (k, args) -> do
@@ -290,8 +294,9 @@ findVarInhabitants var nref@(_, cns) =
return (mkIPMatch k <$> argPossibilities)
Nothing -> case nub negMatches of
[] -> Poss.retSingle $ IPNot []
- neg ->
- case TI.tyDataCons . TI.getType $ var of
+ neg -> do
+ tyCtx <- ask @Ty.TyDefCtx
+ case TI.tyDataCons (TI.getType var) tyCtx of
Nothing -> Poss.retSingle $ IPNot neg
Just dcs ->
do
@@ -313,7 +318,7 @@ findVarInhabitants var nref@(_, cns) =
posMatch = C.posMatch constraintsOnX
negMatches = C.negMatches constraintsOnX
-findRedundant :: (Member Fresh r) => Ant -> [TI.TypedVar] -> Sem r [Int]
+findRedundant :: Members '[Fresh, Reader Ty.TyDefCtx] r => Ant -> [TI.TypedVar] -> Sem r [Int]
findRedundant ant args = case ant of
AGrhs ref i -> do
uninhabited <- Poss.none <$> findInhabitants ref args
diff --git a/src/Disco/Exhaustiveness/Constraint.hs b/src/Disco/Exhaustiveness/Constraint.hs
index 646c97d6..a1ced538 100644
--- a/src/Disco/Exhaustiveness/Constraint.hs
+++ b/src/Disco/Exhaustiveness/Constraint.hs
@@ -1,7 +1,7 @@
module Disco.Exhaustiveness.Constraint where
import Control.Applicative (Alternative)
-import Control.Monad (foldM, guard, replicateM)
+import Control.Monad (foldM, guard)
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Maybe (MaybeT, runMaybeT)
import Data.List (partition)
@@ -9,6 +9,8 @@ import Data.Maybe (isJust, listToMaybe, mapMaybe)
import Disco.Effects.Fresh (Fresh)
import qualified Disco.Exhaustiveness.TypeInfo as TI
import Polysemy
+import qualified Disco.Types as Ty
+import Polysemy.Reader
newtype Context = Context {getCtxVars :: [TI.TypedVar]}
deriving (Show, Eq)
@@ -47,15 +49,15 @@ onVar x cs = alistLookup (lookupVar x cs) cs
type NormRefType = (Context, [ConstraintFor])
-addConstraints :: (Members '[Fresh] r) => NormRefType -> [ConstraintFor] -> MaybeT (Sem r) NormRefType
+addConstraints :: Members '[Fresh, Reader Ty.TyDefCtx] r => NormRefType -> [ConstraintFor] -> MaybeT (Sem r) NormRefType
addConstraints = foldM addConstraint
-addConstraint :: (Members '[Fresh] r) => NormRefType -> ConstraintFor -> MaybeT (Sem r) NormRefType
+addConstraint :: Members '[Fresh, Reader Ty.TyDefCtx] r => NormRefType -> ConstraintFor -> MaybeT (Sem r) NormRefType
addConstraint nref@(_, cns) (x, c) = do
breakIf $ any (conflictsWith c) (onVar x cns)
addConstraintHelper nref (lookupVar x cns, c)
-addConstraintHelper :: (Members '[Fresh] r) => NormRefType -> ConstraintFor -> MaybeT (Sem r) NormRefType
+addConstraintHelper :: Members '[Fresh, Reader Ty.TyDefCtx] r => NormRefType -> ConstraintFor -> MaybeT (Sem r) NormRefType
addConstraintHelper nref@(ctx, cns) cf@(origX, c) = case c of
--- Equation (10)
CMatch k args -> do
@@ -126,17 +128,19 @@ substituteVarIDs y x = map (\(var, c) -> (subst var, c))
-- This function tests if this is true
-- NOTE(colin): we may eventually have type constraints
-- and we would need to worry pulling them from nref here
-inhabited :: (Members '[Fresh] r) => NormRefType -> TI.TypedVar -> Sem r Bool
-inhabited n var = case TI.tyDataCons . TI.getType $ var of
- Nothing -> return True -- assume opaque types are inhabited
- Just constructors -> do
- or <$> mapM (instantiate n var) constructors
+inhabited :: Members '[Fresh, Reader Ty.TyDefCtx] r => NormRefType -> TI.TypedVar -> Sem r Bool
+inhabited n var = do
+ tyCtx <- ask @Ty.TyDefCtx
+ case TI.tyDataCons (TI.getType var) tyCtx of
+ Nothing -> return True -- assume opaque types are inhabited
+ Just constructors -> do
+ or <$> mapM (instantiate n var) constructors
-- Attempts to "instantiate" a match of the dataconstructor k on x
-- If we can add the MatchDataCon constraint to the normalized refinement
-- type without contradiction (a Nothing value),
-- then x is inhabited by k and we return true
-instantiate :: (Members '[Fresh] r) => NormRefType -> TI.TypedVar -> TI.DataCon -> Sem r Bool
+instantiate :: Members '[Fresh, Reader Ty.TyDefCtx] r => NormRefType -> TI.TypedVar -> TI.DataCon -> Sem r Bool
instantiate (ctx, cns) var k = do
args <- TI.newVars $ TI.dcTypes k
let attempt = (ctx `addVars` args, cns) `addConstraint` (var, CMatch k args)
diff --git a/src/Disco/Exhaustiveness/TypeInfo.hs b/src/Disco/Exhaustiveness/TypeInfo.hs
index 462158a2..1fd098dc 100644
--- a/src/Disco/Exhaustiveness/TypeInfo.hs
+++ b/src/Disco/Exhaustiveness/TypeInfo.hs
@@ -1,6 +1,7 @@
module Disco.Exhaustiveness.TypeInfo where
import Control.Monad (replicateM)
+import qualified Data.Map as M
import Disco.AST.Typed (ATerm)
import Disco.Effects.Fresh (Fresh, fresh)
import qualified Disco.Types as Ty
@@ -63,32 +64,38 @@ left tl = DataCon {dcIdent = KLeft, dcTypes = [tl]}
right :: Ty.Type -> DataCon
right tr = DataCon {dcIdent = KRight, dcTypes = [tr]}
-{-
-TODO(colin): Fill out the remaining types here
-Remaining:
- TyVar
- , TySkolem
- , TyProp
- , TyBag
- , TySet
- , TyGraph
- , TyMap
- , TyUser
-Impossible:
- , (:->:)
--}
-tyDataCons :: Ty.Type -> Maybe [DataCon]
-tyDataCons (a Ty.:*: b) = Just [pair a b]
-tyDataCons (l Ty.:+: r) = Just [left l, right r]
-tyDataCons t@(Ty.TyList a) = Just [cons a t, nil]
-tyDataCons Ty.TyVoid = Just []
-tyDataCons Ty.TyUnit = Just [unit]
-tyDataCons Ty.TyBool = Just [bool True, bool False]
-tyDataCons Ty.TyN = Nothing
-tyDataCons Ty.TyZ = Nothing
-tyDataCons Ty.TyF = Nothing
-tyDataCons Ty.TyQ = Nothing
-tyDataCons Ty.TyC = Nothing
+-- TODO(colin): ask yorgey, make sure I've done this correctly
+-- If I have, and this is enough, I can remove all mentions
+-- of type equality constraints in Constraint.hs,
+-- the lookup here will have handled that behavoir already
+tyDataCons :: Ty.Type -> Ty.TyDefCtx -> Maybe [DataCon]
+tyDataCons (Ty.TyUser name args) ctx = case M.lookup name ctx of
+ Nothing -> error $ "Type definition not found for: " ++ show name
+ Just (Ty.TyDefBody _argNames typeCon) -> tyDataCons (typeCon args) ctx
+tyDataCons (a Ty.:*: b) _ = Just [pair a b]
+tyDataCons (l Ty.:+: r) _ = Just [left l, right r]
+tyDataCons t@(Ty.TyList a) _ = Just [cons a t, nil]
+tyDataCons Ty.TyVoid _ = Just []
+tyDataCons Ty.TyUnit _ = Just [unit]
+tyDataCons Ty.TyBool _ = Just [bool True, bool False]
+tyDataCons Ty.TyN _ = Nothing
+tyDataCons Ty.TyZ _ = Nothing
+tyDataCons Ty.TyF _ = Nothing
+tyDataCons Ty.TyQ _ = Nothing
+tyDataCons Ty.TyC _ = Nothing
+tyDataCons (_ Ty.:->: _) _ = error "Functions not allowed in patterns."
+tyDataCons (Ty.TySet _) _ = error "Sets not allowed in patterns."
+tyDataCons (Ty.TyBag _) _ = error "Bags not allowed in patterns."
+-- I'm unsure about these two.
+-- They may come up when doing generic stuff,
+-- I haven't ecnountered them so far
+-- TODO(colin): ask Yorgey about this
+tyDataCons (Ty.TyVar _) _ = error "Encountered type var in pattern"
+tyDataCons (Ty.TySkolem _) _ = error "Encountered skolem in pattern"
+-- Unsure about these as well
+tyDataCons (Ty.TyProp) _ = error "Propositions not allowed in patterns."
+tyDataCons (Ty.TyMap _ _) _ = error "Maps not allowed in patterns."
+tyDataCons (Ty.TyGraph _) _ = error "Graph not allowed in patterns."
newName :: (Member Fresh r) => Sem r (Name ATerm)
newName = fresh $ s2n ""