Skip to content

Commit

Permalink
Make min and max prefix functions instead of infix operators (#383)
Browse files Browse the repository at this point in the history
Min and max are now primitive, built-in prefix functions.

Eventually we should move them into the standard library, once we implement #179 .
  • Loading branch information
byorgey authored May 23, 2024
1 parent eae1326 commit 647e5cf
Show file tree
Hide file tree
Showing 14 changed files with 140 additions and 174 deletions.
2 changes: 1 addition & 1 deletion example/demo.disco
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ type P(a) = a + P(a) * P(a)

height : P(a) -> N
height (left(_)) = 0
height (right(l, r)) = 1 + (height l) max (height r)
height (right(l, r)) = 1 + max(height l, height r)
2 changes: 1 addition & 1 deletion example/tree.disco
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ treeSize : Tree -> N
treeSize(t) = treeFold(0, \(x,l,r). 1 + l + r, t)

treeHeight : Tree -> N
treeHeight(t) = treeFold(0, \(x,l,r). 1 + l max r, t)
treeHeight(t) = treeFold(0, \(x,l,r). 1 + max(l,r), t)
45 changes: 25 additions & 20 deletions src/Disco/Desugar.hs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,26 @@ desugarCList2B p ty cts b = do
)
return $ mkLambda ty [c] body

-- | Desugar the @min@ and @max@ functions into conditional expressions.
desugarMinMax :: Member Fresh r => Prim -> Type -> Sem r DTerm
desugarMinMax m ty = do
p <- fresh (string2Name "p")
a <- fresh (string2Name "a")
b <- fresh (string2Name "b")
body <-
desugarTerm $
ATCase
ty
[ bind
(toTelescope [AGPat (embed (atVar ty p)) (APTup (ty :*: ty) [APVar ty a, APVar ty b])])
$ ATCase
ty
[ atVar ty (if m == PrimMin then a else b) <==. [tif (atVar ty a <. atVar ty b)]
, atVar ty (if m == PrimMin then b else a) <==. []
]
]
return $ mkLambda ((ty :*: ty) :->: ty) [p] body

-- | Desugar a typechecked term.
desugarTerm :: Member Fresh r => ATerm -> Sem r DTerm
desugarTerm (ATVar ty x) = return $ DTVar ty (coerce x)
Expand All @@ -274,6 +294,8 @@ desugarTerm (ATPrim (ty1 :*: ty2 :->: resTy) (PrimBOp bop))
| bopDesugars ty1 ty2 resTy bop = desugarPrimBOp ty1 ty2 resTy bop
desugarTerm (ATPrim ty@(TyList cts :->: TyBag b) PrimC2B) = desugarCList2B PrimC2B ty cts b
desugarTerm (ATPrim ty@(TyList cts :->: TyBag b) PrimUC2B) = desugarCList2B PrimUC2B ty cts b
desugarTerm (ATPrim (_ :->: ty) PrimMin) = desugarMinMax PrimMin ty
desugarTerm (ATPrim (_ :->: ty) PrimMax) = desugarMinMax PrimMax ty
desugarTerm (ATPrim ty x) = return $ DTPrim ty x
desugarTerm ATUnit = return DTUnit
desugarTerm (ATBool ty b) = return $ DTBool ty b
Expand Down Expand Up @@ -345,8 +367,6 @@ bopDesugars _ _ _ bop =
, Gt
, Leq
, Geq
, Min
, Max
, IDiv
, Sub
, SSub
Expand Down Expand Up @@ -424,21 +444,6 @@ desugarBinApp _ Neq t1 t2 = desugarTerm $ tnot (t1 ==. t2)
desugarBinApp _ Gt t1 t2 = desugarTerm $ t2 <. t1
desugarBinApp _ Leq t1 t2 = desugarTerm $ tnot (t2 <. t1)
desugarBinApp _ Geq t1 t2 = desugarTerm $ tnot (t1 <. t2)
-- XXX sharing!
desugarBinApp ty Min t1 t2 =
desugarTerm $
ATCase
ty
[ t1 <==. [tif (t1 <. t2)]
, t2 <==. []
]
desugarBinApp ty Max t1 t2 =
desugarTerm $
ATCase
ty
[ t1 <==. [tif (t2 <. t1)]
, t2 <==. []
]
-- t1 // t2 ==> floor (t1 / t2)
desugarBinApp resTy IDiv t1 t2 =
desugarTerm $
Expand Down Expand Up @@ -498,9 +503,9 @@ desugarBinApp ty op t1 t2
, t2
]
where
mergeOp _ Inter = PrimBOp Min
mergeOp _ Inter = PrimMin
mergeOp _ Diff = PrimBOp SSub
mergeOp (TySet _) Union = PrimBOp Max
mergeOp (TySet _) Union = PrimMax
mergeOp (TyBag _) Union = PrimBOp Add
mergeOp _ _ = error $ "Impossible! mergeOp " ++ show ty ++ " " ++ show op

Expand All @@ -514,7 +519,7 @@ desugarBinApp _ Subset t1 t2 =
(ATPrim (ty :*: ty :->: TyBool) (PrimBOp Eq))
[ tapps
(ATPrim ((TyN :*: TyN :->: TyN) :*: ty :*: ty :->: ty) PrimMerge)
[ ATPrim (TyN :*: TyN :->: TyN) (PrimBOp Max)
[ ATPrim (TyN :*: TyN :->: TyN) PrimMax
, t1
, t2
]
Expand Down
2 changes: 2 additions & 0 deletions src/Disco/Doc.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ primDoc =
, PrimFloor ==> "floor(x) is the largest integer which is <= x."
, PrimCeil ==> "ceiling(x) is the smallest integer which is >= x."
, PrimAbs ==> "abs(x) is the absolute value of x. Also written |x|."
, PrimMin ==> "min(x,y) is the minimum of x and y, i.e. whichever is smaller."
, PrimMax ==> "max(x,y) is the maximum of x and y, i.e. whichever is larger."
, PrimUOp Not ==> "Logical negation: not(true) = false and not(false) = true."
, PrimBOp And ==> "Logical conjunction (and): true /\\ true = true; otherwise x /\\ y = false."
, PrimBOp Or ==> "Logical disjunction (or): false \\/ false = false; otherwise x \\/ y = true."
Expand Down
2 changes: 0 additions & 2 deletions src/Disco/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,6 @@ reservedWords =
, "choose"
, "implies"
, "iff"
, "min"
, "max"
, "union"
, ""
, "intersect"
Expand Down
8 changes: 0 additions & 8 deletions src/Disco/Syntax/Operators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ data BOp
Leq
| -- | Greater than or equal (@>=@)
Geq
| -- | Minimum (@min@)
Min
| -- | Maximum (@max@)
Max
| -- | Logical and (@&&@ / @and@)
And
| -- | Logical or (@||@ / @or@)
Expand Down Expand Up @@ -202,10 +198,6 @@ opTable =
, bopInfo InL Inter ["intersect", ""]
, bopInfo InL Diff ["\\"]
]
,
[ bopInfo InL Min ["min"]
, bopInfo InL Max ["max"]
]
,
[ bopInfo InL Mul ["*"]
, bopInfo InL Div ["/"]
Expand Down
195 changes: 81 additions & 114 deletions src/Disco/Syntax/Prims.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,120 +37,85 @@ import Disco.Syntax.Operators

-- | Primitives, /i.e./ built-in constants.
data Prim where
PrimUOp ::
UOp ->
-- | Unary operator
Prim
PrimBOp ::
BOp ->
-- | Binary operator
Prim
PrimLeft ::
-- | Left injection into a sum type.
Prim
PrimRight ::
-- | Right injection into a sum type.
Prim
PrimSqrt ::
-- | Integer square root (@sqrt@)
Prim
PrimFloor ::
-- | Floor of fractional type (@floor@)
Prim
PrimCeil ::
-- | Ceiling of fractional type (@ceiling@)
Prim
PrimAbs ::
-- | Absolute value (@abs@)
Prim
PrimPower ::
-- | Power set (XXX or bag?)
Prim
PrimList ::
-- | Container -> list conversion
Prim
PrimBag ::
-- | Container -> bag conversion
Prim
PrimSet ::
-- | Container -> set conversion
Prim
PrimB2C ::
-- | bag -> set of counts conversion
Prim
PrimC2B ::
-- | set of counts -> bag conversion
Prim
PrimUC2B ::
-- | unsafe set of counts -> bag conversion
-- that assumes all distinct
Prim
PrimMapToSet ::
-- | Map k v -> Set (k × v)
Prim
PrimSetToMap ::
-- | Set (k × v) -> Map k v
Prim
PrimSummary ::
-- | Get Adjacency list of Graph
Prim
PrimVertex ::
-- | Construct a graph Vertex
Prim
PrimEmptyGraph ::
-- | Empty graph
Prim
PrimOverlay ::
-- | Overlay two Graphs
Prim
PrimConnect ::
-- | Connect Graph to another with directed edges
Prim
PrimInsert ::
-- | Insert into map
Prim
PrimLookup ::
-- | Get value associated with key in map
Prim
PrimEach ::
-- | Each operation for containers
Prim
PrimReduce ::
-- | Reduce operation for containers
Prim
PrimFilter ::
-- | Filter operation for containers
Prim
PrimJoin ::
-- | Monadic join for containers
Prim
PrimMerge ::
-- | Generic merge operation for bags/sets
Prim
PrimIsPrime ::
-- | Efficient primality test
Prim
PrimFactor ::
-- | Factorization
Prim
PrimFrac ::
-- | Turn a rational into a pair (num, denom)
Prim
PrimCrash ::
-- | Crash
Prim
PrimUntil ::
-- | @[x, y, z .. e]@
Prim
PrimHolds ::
-- | Test whether a proposition holds
Prim
PrimLookupSeq ::
-- | Lookup OEIS sequence
Prim
PrimExtendSeq ::
-- | Extend OEIS sequence
Prim
-- | Unary operator
PrimUOp :: UOp -> Prim
-- | Binary operator
PrimBOp :: BOp -> Prim
-- | Left injection into a sum type.
PrimLeft :: Prim
-- | Right injection into a sum type.
PrimRight :: Prim
-- | Integer square root (@sqrt@)
PrimSqrt :: Prim
-- | Floor of fractional type (@floor@)
PrimFloor :: Prim
-- | Ceiling of fractional type (@ceiling@)
PrimCeil :: Prim
-- | Absolute value (@abs@)
PrimAbs :: Prim
-- | Min
PrimMin :: Prim
-- | Max
PrimMax :: Prim
-- | Power set (XXX or bag?)
PrimPower :: Prim
-- | Container -> list conversion
PrimList :: Prim
-- | Container -> bag conversion
PrimBag :: Prim
-- | Container -> set conversion
PrimSet :: Prim
-- | bag -> set of counts conversion
PrimB2C :: Prim
-- | set of counts -> bag conversion
PrimC2B :: Prim
-- | unsafe set of counts -> bag conversion
-- that assumes all distinct
PrimUC2B :: Prim
-- | Map k v -> Set (k × v)
PrimMapToSet :: Prim
-- | Set (k × v) -> Map k v
PrimSetToMap :: Prim
-- | Get Adjacency list of Graph
PrimSummary :: Prim
-- | Construct a graph Vertex
PrimVertex :: Prim
-- | Empty graph
PrimEmptyGraph :: Prim
-- | Overlay two Graphs
PrimOverlay :: Prim
-- | Connect Graph to another with directed edges
PrimConnect :: Prim
-- | Insert into map
PrimInsert :: Prim
-- | Get value associated with key in map
PrimLookup :: Prim
-- | Each operation for containers
PrimEach :: Prim
-- | Reduce operation for containers
PrimReduce :: Prim
-- | Filter operation for containers
PrimFilter :: Prim
-- | Monadic join for containers
PrimJoin :: Prim
-- | Generic merge operation for bags/sets
PrimMerge :: Prim
-- | Efficient primality test
PrimIsPrime :: Prim
-- | Factorization
PrimFactor :: Prim
-- | Turn a rational into a pair (num, denom)
PrimFrac :: Prim
-- | Crash
PrimCrash :: Prim
-- | @[x, y, z .. e]@
PrimUntil :: Prim
-- | Test whether a proposition holds
PrimHolds :: Prim
-- | Lookup OEIS sequence
PrimLookupSeq :: Prim
-- | Extend OEIS sequence
PrimExtendSeq :: Prim
deriving (Show, Read, Eq, Ord, Generic, Alpha, Subst t, Data)

------------------------------------------------------------
Expand Down Expand Up @@ -192,6 +157,8 @@ primTable =
, PrimInfo PrimFloor "floor" True
, PrimInfo PrimCeil "ceiling" True
, PrimInfo PrimAbs "abs" True
, PrimInfo PrimMin "min" True
, PrimInfo PrimMax "max" True
, PrimInfo PrimPower "power" True
, PrimInfo PrimList "list" True
, PrimInfo PrimBag "bag" True
Expand Down
22 changes: 12 additions & 10 deletions src/Disco/Typecheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -885,16 +885,6 @@ typecheck Infer (TPrim prim) = do
inferPrim (PrimBOp Geq) = error "inferPrim Geq should be unreachable"
------------------------------------------------------------

inferPrim (PrimBOp op) | op `elem` [Min, Max] = do
ty <- freshTy
constraint $ CQual QCmp ty
return $ ty :*: ty :->: ty

-- See Note [Pattern coverage] -----------------------------
inferPrim (PrimBOp Min) = error "inferPrim Min should be unreachable"
inferPrim (PrimBOp Max) = error "inferPrim Max should be unreachable"
------------------------------------------------------------

----------------------------------------
-- Special arithmetic functions: fact, sqrt, floor, ceil, abs

Expand All @@ -916,6 +906,18 @@ typecheck Infer (TPrim prim) = do
cAbs argTy resTy `cOr` cSize argTy resTy
return $ argTy :->: resTy

----------------------------------------
-- min/max

inferPrim PrimMin = do
a <- freshTy
constraint $ CQual QCmp a
return $ (a :*: a) :->: a
inferPrim PrimMax = do
a <- freshTy
constraint $ CQual QCmp a
return $ (a :*: a) :->: a

----------------------------------------
-- power set/bag

Expand Down
Loading

0 comments on commit 647e5cf

Please sign in to comment.