Skip to content

Commit

Permalink
Annotate ASTs with types at every node (#991)
Browse files Browse the repository at this point in the history
- Closes #990

Values of type `Syntax` are as before: parsed syntax, with each node annotated with `SrcLoc`.

Values of type `Syntax' Polytype`, however, have each node annotated with *both* a `SrcLoc` *and* a `Polytype`.  (`Syntax` is really just a synonym for `Syntax' ()`.)

Type inference takes a `Syntax` and outputs a `TModule`, which now contains a `Syntax' Polytype`, in other words, a new version of the AST where every node has been annotated with the inferred type of the subterm rooted there.

---

Why is this useful?
1. It will enable us to do type-specific elaboration/rewriting.  For example I think this will allow us to solve #681 , because we only want to apply a rewrite to variables with a command type.
2. It makes type information for any specific subterm easily available.  For example I hope we will be able to use this to enhance the `OnHover` LSP handler, e.g. to show the type of the term under the mouse.

I imagine the code changes might look kind of intimidating but I don't think it's really that bad once you understand what is going on, so I'm happy to answer any questions or explain anything.
  • Loading branch information
byorgey authored Jan 25, 2023
1 parent 07673d1 commit 1678b49
Show file tree
Hide file tree
Showing 18 changed files with 557 additions and 326 deletions.
26 changes: 21 additions & 5 deletions src/Swarm/Game/CESK.hs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ module Swarm.Game.CESK (
finalValue,
) where

import Control.Lens ((^.))
import Control.Lens.Combinators (pattern Empty)
import Data.Aeson (FromJSON, ToJSON)
import Data.IntMap.Strict (IntMap)
Expand All @@ -96,6 +97,7 @@ import Swarm.Game.Exception
import Swarm.Game.Value as V
import Swarm.Game.World (WorldUpdate (..))
import Swarm.Language.Context
import Swarm.Language.Module
import Swarm.Language.Pipeline
import Swarm.Language.Pretty
import Swarm.Language.Requirement (ReqCtx)
Expand Down Expand Up @@ -284,11 +286,25 @@ initMachine t e s = initMachine' t e s []

-- | Like 'initMachine', but also take an explicit starting continuation.
initMachine' :: ProcessedTerm -> Env -> Store -> Cont -> CESK
initMachine' (ProcessedTerm t (Module (Forall _ (TyCmd _)) ctx) _ reqCtx) e s k =
case ctx of
Empty -> In t e s (FExec : k)
_ -> In t e s (FExec : FLoadEnv ctx reqCtx : k)
initMachine' (ProcessedTerm t _ _ _) e s k = In t e s k
initMachine' (ProcessedTerm (Module t' ctx) _ reqCtx) e s k =
case t' ^. sType of
-- If the starting term has a command type...
Forall _ (TyCmd _) ->
case ctx of
-- ...but doesn't contain any definitions, just create a machine
-- that will evaluate it and then execute it.
Empty -> In t e s (FExec : k)
-- Or, if it does contain definitions, then load the resulting
-- context after executing it.
_ -> In t e s (FExec : FLoadEnv ctx reqCtx : k)
-- Otherwise, for a term with a non-command type, just
-- create a machine to evaluate it.
_ -> In t e s k
where
-- Erase all type and SrcLoc annotations from the term before
-- putting it in the machine state, since those are irrelevant at
-- runtime.
t = eraseS t'

-- | Cancel the currently running computation.
cancel :: CESK -> CESK
Expand Down
2 changes: 1 addition & 1 deletion src/Swarm/Game/State.hs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ import Swarm.Language.Capability (constCaps)
import Swarm.Language.Context qualified as Ctx
import Swarm.Language.Pipeline (ProcessedTerm)
import Swarm.Language.Pipeline.QQ (tmQ)
import Swarm.Language.Syntax (Const, Term (TText), allConst)
import Swarm.Language.Syntax (Const, Term' (TText), allConst)
import Swarm.Language.Typed (Typed (Typed))
import Swarm.Language.Types
import Swarm.TUI.Model.Achievement.Attainment
Expand Down
2 changes: 1 addition & 1 deletion src/Swarm/Game/Step.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,7 @@ execConst c vs s k = do

case mt of
Nothing -> return $ Out VUnit s k
Just t@(ProcessedTerm _ _ _ reqCtx) -> do
Just t@(ProcessedTerm _ _ reqCtx) -> do
-- Add the reqCtx from the ProcessedTerm to the current robot's defReqs.
-- See #827 for an explanation of (1) why this is needed, (2) why
-- it's slightly technically incorrect, and (3) why it is still way
Expand Down
2 changes: 1 addition & 1 deletion src/Swarm/Game/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ valueToTerm (VClo x t e) =
M.foldrWithKey
(\y v -> TLet False y Nothing (valueToTerm v))
(TLam x Nothing t)
(M.restrictKeys (unCtx e) (S.delete x (setOf fv t)))
(M.restrictKeys (unCtx e) (S.delete x (setOf freeVarsV (Syntax' NoLoc t ()))))
valueToTerm (VCApp c vs) = foldl' TApp (TConst c) (reverse (map valueToTerm vs))
valueToTerm (VDef r x t _) = TDef r x Nothing t
valueToTerm (VResult v _) = valueToTerm v
Expand Down
65 changes: 39 additions & 26 deletions src/Swarm/Language/Elaborate.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module : Swarm.Language.Elaborate
-- Copyright : Brent Yorgey
Expand All @@ -8,40 +10,51 @@
-- Term elaboration which happens after type checking.
module Swarm.Language.Elaborate where

import Control.Lens (transform, (%~))
import Control.Lens (transform, (%~), (^.))
import Swarm.Language.Syntax
import Swarm.Language.Types

-- | Perform some elaboration / rewriting on a fully type-annotated
-- term, given its top-level type. This currently performs such
-- operations as rewriting @if@ expressions and recursive let
-- expressions to use laziness appropriately. In theory it could
-- also perform rewriting for overloaded constants depending on the
-- actual type they are used at, but currently that sort of thing
-- tends to make type inference fall over.
elaborate :: Term -> Term
-- term. This currently performs such operations as rewriting @if@
-- expressions and recursive let expressions to use laziness
-- appropriately. In theory it could also perform rewriting for
-- overloaded constants depending on the actual type they are used
-- at, but currently that sort of thing tends to make type inference
-- fall over.
elaborate :: Syntax' Polytype -> Syntax' Polytype
elaborate =
-- Wrap all *free* variables in 'Force'. Free variables must be
-- referring to a previous definition, which are all wrapped in
-- 'TDelay'.
(fvT %~ TApp (TConst Force))
(freeVarsS %~ \s -> Syntax' (s ^. sLoc) (SApp sForce s) (s ^. sType))
-- Now do additional rewriting on all subterms.
. transform rewrite
where
-- For recursive let bindings, rewrite any occurrences of x to
-- (force x). When interpreting t1, we will put a binding (x |->
-- delay t1) in the context.
rewrite (TLet True x ty t1 t2) = TLet True x ty (wrapForce x t1) (wrapForce x t2)
-- Rewrite any recursive occurrences of x inside t1 to (force x).
-- When a TDef is encountered at runtime its body will immediately
-- be wrapped in a VDelay. However, to make this work we also need
-- to wrap all free variables in any term with 'force' --- since
-- any such variables must in fact refer to things previously
-- bound by 'def'.
rewrite (TDef True x ty t1) = TDef True x ty (mapFree1 x (TApp (TConst Force)) t1)
-- Rewrite @f $ x@ to @f x@.
rewrite (TApp (TApp (TConst AppF) r) l) = TApp r l
-- Leave any other subterms alone.
rewrite t = t
rewrite :: Syntax' Polytype -> Syntax' Polytype
rewrite (Syntax' l t ty) = Syntax' l (rewriteTerm t) ty

rewriteTerm :: Term' Polytype -> Term' Polytype
rewriteTerm = \case
-- For recursive let bindings, rewrite any occurrences of x to
-- (force x). When interpreting t1, we will put a binding (x |->
-- delay t1) in the context.
SLet True x ty t1 t2 -> SLet True x ty (wrapForce (lvVar x) t1) (wrapForce (lvVar x) t2)
-- Rewrite any recursive occurrences of x inside t1 to (force x).
-- When a TDef is encountered at runtime its body will immediately
-- be wrapped in a VDelay. However, to make this work we also need
-- to wrap all free variables in any term with 'force' --- since
-- any such variables must in fact refer to things previously
-- bound by 'def'.
SDef True x ty t1 -> SDef True x ty (wrapForce (lvVar x) t1)
-- Rewrite @f $ x@ to @f x@.
SApp (Syntax' _ (SApp (Syntax' _ (TConst AppF) _) l) _) r -> SApp l r
-- Leave any other subterms alone.
t -> t

wrapForce :: Var -> Syntax' Polytype -> Syntax' Polytype
wrapForce x = mapFreeS x (\s@(Syntax' l _ ty) -> Syntax' l (SApp sForce s) ty)

-- Note, TyUnit is not the right type, but I don't want to bother

wrapForce :: Var -> Term -> Term
wrapForce x = mapFree1 x (TApp (TConst Force))
sForce :: Syntax' Polytype
sForce = Syntax' NoLoc (TConst Force) (Forall ["a"] (TyDelay (TyVar "a") :->: TyVar "a"))
51 changes: 51 additions & 0 deletions src/Swarm/Language/Module.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
-- |
-- Module : Swarm.Language.Module
-- Copyright : Brent Yorgey
-- Maintainer : [email protected]
--
-- SPDX-License-Identifier: BSD-3-Clause
--
-- A 'Module' packages together a type-annotated syntax tree along
-- with a context of top-level definitions.
module Swarm.Language.Module (
-- * Modules
Module (..),
TModule,
UModule,
trivMod,
) where

import Data.Data (Data)
import Data.Yaml (FromJSON, ToJSON)
import GHC.Generics (Generic)
import Swarm.Language.Context (Ctx, empty)
import Swarm.Language.Syntax (Syntax')
import Swarm.Language.Types (Polytype, UPolytype, UType)

------------------------------------------------------------
-- Modules
------------------------------------------------------------

-- | A module generally represents the result of performing type
-- inference on a top-level expression, which in particular can
-- contain definitions ('Swarm.Language.Syntax.TDef'). A module
-- contains the type-annotated AST of the expression itself, as well
-- as the context giving the types of any defined variables.
data Module s t = Module {moduleAST :: Syntax' s, moduleCtx :: Ctx t}
deriving (Show, Eq, Functor, Data, Generic, FromJSON, ToJSON)

-- | A 'TModule' is the final result of the type inference process on
-- an expression: we get a polytype for the expression, and a
-- context of polytypes for the defined variables.
type TModule = Module Polytype Polytype

-- | A 'UModule' represents the type of an expression at some
-- intermediate stage during the type inference process. We get a
-- 'UType' (/not/ a 'UPolytype') for the expression, which may
-- contain some free unification or type variables, as well as a
-- context of 'UPolytype's for any defined variables.
type UModule = Module UType UPolytype

-- | The trivial module for a given AST, with the empty context.
trivMod :: Syntax' s -> Module s t
trivMod t = Module t empty
29 changes: 16 additions & 13 deletions src/Swarm/Language/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ module Swarm.Language.Parse (
getLocRange,
) where

import Control.Lens (view, (^.))
import Control.Monad.Combinators.Expr
import Control.Monad.Reader
import Data.Bifunctor
Expand Down Expand Up @@ -277,7 +278,7 @@ parseTermAtom =
<$> (reserved "def" *> locIdentifier)
<*> optional (symbol ":" *> parsePolytype)
<*> (symbol "=" *> parseTerm <* reserved "end")
<|> parens (mkTuple <$> (parseTerm `sepBy` symbol ","))
<|> parens (view sTerm . mkTuple <$> (parseTerm `sepBy` symbol ","))
)
-- Potential syntax for explicitly requesting memoized delay.
-- Perhaps we will not need this in the end; see the discussion at
Expand All @@ -289,20 +290,22 @@ parseTermAtom =
<|> parseLoc (SDelay SimpleDelay <$> braces parseTerm)
<|> parseLoc (ask >>= (guard . (== AllowAntiquoting)) >> parseAntiquotation)

mkTuple :: [Syntax] -> Term
mkTuple [] = TUnit
mkTuple [STerm x] = x
mkTuple (x : xs) = SPair x (STerm (mkTuple xs))
mkTuple :: [Syntax] -> Syntax
mkTuple [] = Syntax NoLoc TUnit -- should never happen
mkTuple [x] = x
mkTuple (x : xs) = let r = mkTuple xs in loc x r $ SPair x r
where
loc a b = Syntax $ (a ^. sLoc) <> (b ^. sLoc)

-- | Construct an 'SLet', automatically filling in the Boolean field
-- indicating whether it is recursive.
sLet :: LocVar -> Maybe Polytype -> Syntax -> Syntax -> Term
sLet x ty t1 = SLet (lvVar x `S.member` setOf fv (sTerm t1)) x ty t1
sLet x ty t1 = SLet (lvVar x `S.member` setOf freeVarsV t1) x ty t1

-- | Construct an 'SDef', automatically filling in the Boolean field
-- indicating whether it is recursive.
sDef :: LocVar -> Maybe Polytype -> Syntax -> Term
sDef x ty t = SDef (lvVar x `S.member` setOf fv (sTerm t)) x ty t
sDef x ty t = SDef (lvVar x `S.member` setOf freeVarsV t) x ty t

parseAntiquotation :: Parser Term
parseAntiquotation =
Expand All @@ -318,9 +321,9 @@ mkBindChain stmts = case last stmts of
Binder x _ -> return $ foldr mkBind (STerm (TApp (TConst Return) (TVar (lvVar x)))) stmts
BareTerm t -> return $ foldr mkBind t (init stmts)
where
mkBind (BareTerm t1) t2 = loc t1 t2 $ SBind Nothing t1 t2
mkBind (Binder x t1) t2 = loc t1 t2 $ SBind (Just x) t1 t2
loc a b = Syntax $ sLoc a <> sLoc b
mkBind (BareTerm t1) t2 = loc Nothing t1 t2 $ SBind Nothing t1 t2
mkBind (Binder x t1) t2 = loc (Just x) t1 t2 $ SBind (Just x) t1 t2
loc mx a b = Syntax $ maybe NoLoc lvSrcLoc mx <> (a ^. sLoc) <> (b ^. sLoc)

data Stmt
= BareTerm Syntax
Expand All @@ -347,7 +350,7 @@ fixDefMissingSemis term =
[] -> term
defs -> foldr1 mkBind defs
where
mkBind t1 t2 = Syntax (sLoc t1 <> sLoc t2) $ SBind Nothing t1 t2
mkBind t1 t2 = Syntax ((t1 ^. sLoc) <> (t2 ^. sLoc)) $ SBind Nothing t1 t2
nestedDefs term' acc = case term' of
def@(Syntax _ SDef {}) -> def : acc
(Syntax _ (SApp nestedTerm def@(Syntax _ SDef {}))) -> nestedDefs nestedTerm (def : acc)
Expand All @@ -370,7 +373,7 @@ parseExpr = fixDefMissingSemis <$> makeExprParser parseTermAtom table
exprLoc2 :: Parser (Syntax -> Syntax -> Term) -> Parser (Syntax -> Syntax -> Syntax)
exprLoc2 p = do
(l, f) <- parseLocG p
pure $ \s1 s2 -> Syntax (l <> sLoc s1 <> sLoc s2) $ f s1 s2
pure $ \s1 s2 -> Syntax (l <> (s1 ^. sLoc) <> (s2 ^. sLoc)) $ f s1 s2

-- | Precedences and parsers of binary operators.
--
Expand Down Expand Up @@ -413,7 +416,7 @@ unOps = Map.unionsWith (++) $ mapMaybe unOpToTuple allConst
exprLoc1 :: Parser (Syntax -> Term) -> Parser (Syntax -> Syntax)
exprLoc1 p = do
(l, f) <- parseLocG p
pure $ \s -> Syntax (l <> sLoc s) $ f s
pure $ \s -> Syntax (l <> s ^. sLoc) $ f s

operatorString :: Text -> Parser Text
operatorString n = (lexeme . try) (string n <* notFollowedBy operatorSymbol)
Expand Down
17 changes: 10 additions & 7 deletions src/Swarm/Language/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ module Swarm.Language.Pipeline (
showTypeErrorPos,
) where

import Control.Lens ((^.))
import Data.Bifunctor (first)
import Data.Data (Data)
import Data.Text (Text)
import Data.Yaml as Y
import GHC.Generics (Generic)
import Swarm.Language.Context
import Swarm.Language.Elaborate
import Swarm.Language.Module
import Swarm.Language.Parse
import Swarm.Language.Pretty
import Swarm.Language.Requirement
Expand All @@ -40,10 +42,8 @@ import Witch
-- pipeline. Put a 'Term' in, and get one of these out.
data ProcessedTerm
= ProcessedTerm
Term
-- ^ The elaborated term
TModule
-- ^ The type of the term (and of any embedded definitions)
-- ^ The elaborated + type-annotated term, plus types of any embedded definitions
Requirements
-- ^ Requirements of the term
ReqCtx
Expand All @@ -60,7 +60,7 @@ instance FromJSON ProcessedTerm where
Right (Just pt) -> return pt

instance ToJSON ProcessedTerm where
toJSON (ProcessedTerm t _ _ _) = String $ prettyText t
toJSON (ProcessedTerm t _ _) = String $ prettyText (moduleAST t)

-- | Given a 'Text' value representing a Swarm program,
--
Expand Down Expand Up @@ -104,6 +104,9 @@ showTypeErrorPos code te = (minusOne start, minusOne end, msg)
-- | Like 'processTerm'', but use a term that has already been parsed.
processParsedTerm' :: TCtx -> ReqCtx -> Syntax -> Either TypeErr ProcessedTerm
processParsedTerm' ctx capCtx t = do
ty <- inferTop ctx t
let (caps, capCtx') = requirements capCtx (sTerm t)
return $ ProcessedTerm (elaborate (sTerm t)) ty caps capCtx'
m <- inferTop ctx t
let (caps, capCtx') = requirements capCtx (t ^. sTerm)
return $ ProcessedTerm (elaborateModule m) caps capCtx'

elaborateModule :: TModule -> TModule
elaborateModule (Module ast ctx) = Module (elaborate ast) ctx
3 changes: 2 additions & 1 deletion src/Swarm/Language/Pipeline/QQ.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Swarm.Language.Parse
import Swarm.Language.Pipeline
import Swarm.Language.Pretty (prettyText)
import Swarm.Language.Syntax
import Swarm.Language.Types (Polytype)
import Swarm.Util (liftText)
import Witch (from)

Expand Down Expand Up @@ -48,7 +49,7 @@ quoteTermExp s = do
Left errMsg -> fail $ from $ prettyText errMsg
Right ptm -> dataToExpQ ((fmap liftText . cast) `extQ` antiTermExp) ptm

antiTermExp :: Term -> Maybe TH.ExpQ
antiTermExp :: Term' Polytype -> Maybe TH.ExpQ
antiTermExp (TAntiText v) =
Just $ TH.appE (TH.conE (TH.mkName "TText")) (TH.varE (TH.mkName (from v)))
antiTermExp (TAntiInt v) =
Expand Down
7 changes: 5 additions & 2 deletions src/Swarm/Language/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ instance PrettyPrec Capability where
instance PrettyPrec Const where
prettyPrec p c = pparens (p > fixity (constInfo c)) $ pretty . syntax . constInfo $ c

instance PrettyPrec (Syntax' ty) where
prettyPrec p = prettyPrec p . eraseS

instance PrettyPrec Term where
prettyPrec _ TUnit = "()"
prettyPrec p (TConst c) = prettyPrec p c
Expand All @@ -122,8 +125,8 @@ instance PrettyPrec Term where
prettyPrec _ (TBool b) = bool "false" "true" b
prettyPrec _ (TRobot r) = "<a" <> pretty r <> ">"
prettyPrec _ (TRef r) = "@" <> pretty r
prettyPrec p (TRequireDevice d) = pparens (p > 10) $ "require" <+> ppr (TText d)
prettyPrec p (TRequire n e) = pparens (p > 10) $ "require" <+> pretty n <+> ppr (TText e)
prettyPrec p (TRequireDevice d) = pparens (p > 10) $ "require" <+> ppr @Term (TText d)
prettyPrec p (TRequire n e) = pparens (p > 10) $ "require" <+> pretty n <+> ppr @Term (TText e)
prettyPrec _ (TVar s) = pretty s
prettyPrec _ (TDelay _ t) = braces $ ppr t
prettyPrec _ t@TPair {} = prettyTuple t
Expand Down
Loading

0 comments on commit 1678b49

Please sign in to comment.