Skip to content

Commit

Permalink
[ new ] System.Concurrency.(Linear/Session)
Browse files Browse the repository at this point in the history
  • Loading branch information
gallais committed Jun 3, 2024
1 parent 940d7da commit e115aca
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 10 deletions.
24 changes: 15 additions & 9 deletions libs/linear/Control/Linear/LIO.idr
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
module Control.Linear.LIO

import Data.Linear.Notation

||| Like `Monad`, but the action and continuation must be run exactly once
||| to ensure that the computation is linear.
public export
interface LinearBind io where
bindL : (1 _ : io a) -> (1 _ : a -> io b) -> io b
bindL : io a -@ (a -> io b) -@ io b

export
LinearBind IO where
Expand Down Expand Up @@ -71,7 +73,7 @@ RunCont Unrestricted t b = t -> b
-- concrete type of the continuation is.
runK : {use : _} ->
LinearBind io =>
(1 _ : L io {use} a) -> (1 _ : RunCont use a (io b)) -> io b
L io {use} a -@ RunCont use a (io b) -@ io b
runK (Pure0 x) k = k x
runK (Pure1 x) k = k x
runK (PureW x) k = k x
Expand All @@ -86,7 +88,7 @@ runK (Bind {u_act = Unrestricted} act next) k = runK act (\x => runK (next x) k)
||| underlying context
export
run : Applicative io => LinearBind io =>
(1 _ : L io a) -> io a
L io a -@ io a
run prog = runK prog pure

export
Expand All @@ -110,12 +112,12 @@ export
export %inline
(>>=) : {u_act : _} ->
LinearBind io =>
(1 _ : L io {use=u_act} a) ->
(1 _ : ContType io u_act u_k a b) -> L io {use=u_k} b
L io {use=u_act} a -@
ContType io u_act u_k a b -@ L io {use=u_k} b
(>>=) = Bind

export
delay : {u_act : _} -> (1 _ : L io {use=u_k} b) -> ContType io u_act u_k () b
delay : {u_act : _} -> L io {use=u_k} b -@ ContType io u_act u_k () b
delay mb = case u_act of
None => \ _ => mb
Linear => \ () => mb
Expand All @@ -124,18 +126,22 @@ delay mb = case u_act of
export %inline
(>>) : {u_act : _} ->
LinearBind io =>
(1 _ : L io {use=u_act} ()) ->
(1 _ : L io {use=u_k} b) -> L io {use=u_k} b
L io {use=u_act} () -@
L io {use=u_k} b -@ L io {use=u_k} b
ma >> mb = ma >>= delay mb

export %inline
pure0 : (0 x : a) -> L io {use=0} a
pure0 = Pure0

export %inline
pure1 : (1 x : a) -> L io {use=1} a
pure1 : a -@ L io {use=1} a
pure1 = Pure1

export %inline
bang : L IO t -@ L1 IO (!* t)
bang io = io >>= \ a => pure1 (MkBang a)

export
(LinearBind io, HasLinearIO io) => HasLinearIO (L io) where
liftIO1 p = Action (liftIO1 p)
Expand Down
35 changes: 35 additions & 0 deletions libs/linear/System/Concurrency/Linear.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module System.Concurrency.Linear

import Control.Linear.LIO

import Data.Linear.Notation
import System.Concurrency

||| Run two linear computations in parallel and return the results.
export
par1 : L1 IO a -@ L1 IO b -@ L1 IO (LPair a b)
par1 x y
= do aChan <- makeChannel
bChan <- makeChannel
aId <- liftIO1 $ fork $ withChannel aChan x
bId <- liftIO1 $ fork $ withChannel bChan y
a <- channelGet aChan
b <- channelGet bChan
pure1 (a # b)

where

-- This unsafe implementation temporarily bypasses the linearity checker.
-- However `par`'s implementation does not duplicate the values
-- and the type of `par` ensures that client code is not allowed to either!
withChannel : Channel t -> L1 IO t -@ IO ()
withChannel ch = assert_linear $ \ act => do
a <- LIO.run (act >>= assert_linear pure)
channelPut ch a

||| Run two unrestricted computations in parallel and return the results.
export
par : L IO a -@ L IO b -@ L IO (a, b)
par x y = do
(MkBang a # MkBang b) <- par1 (bang x) (bang y)
pure (a, b)
226 changes: 226 additions & 0 deletions libs/linear/System/Concurrency/Session.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
module System.Concurrency.Session

import Control.Linear.LIO

import Data.Linear.Notation
import Data.List.AtIndex
import Data.Nat

import Data.OpenUnion
import System
import System.File
import System.Concurrency as Threads
import System.Concurrency.Linear

import Language.Reflection

%default total

------------------------------------------------------------------------
-- Session types

namespace Session

||| A session type describes the interactions one thread may have with
||| another over a shared bidirectional channel: it may send or receive
||| values of arbitrary types, or be done communicating.
public export
data Session : Type where
Send : (ty : Type) -> (s : Session) -> Session
Recv : (ty : Type) -> (s : Session) -> Session
End : Session

||| Dual describes how the other party to the communication sees the
||| interactions: our sends become their receives and vice-versa.
public export
Dual : Session -> Session
Dual (Send ty s) = Recv ty (Dual s)
Dual (Recv ty s) = Send ty (Dual s)
Dual End = End

||| Duality is involutive: the dual of my dual is me
export
dualInvolutive : (s : Session) -> Dual (Dual s) === s
dualInvolutive (Send ty s) = cong (Send ty) (dualInvolutive s)
dualInvolutive (Recv ty s) = cong (Recv ty) (dualInvolutive s)
dualInvolutive End = Refl

||| We can collect the list of types that will be sent over the
||| course of a session by walking down its description
||| This definition is purely internal and will not show up in
||| the library's interface.
SendTypes : Session -> List Type
SendTypes (Send ty s) = ty :: SendTypes s
SendTypes (Recv ty s) = SendTypes s
SendTypes End = []

||| We can collect the list of types that will be received over
||| the course of a session by walking down its description
||| This definition is purely internal and will not show up in
||| the library's interface.
RecvTypes : Session -> List Type
RecvTypes (Send ty s) = RecvTypes s
RecvTypes (Recv ty s) = ty :: RecvTypes s
RecvTypes End = []

||| The types received by my dual are exactly the ones I am sending
||| This definition is purely internal and will not show up in
||| the library's interface.
RecvDualTypes : (s : Session) -> RecvTypes (Dual s) === SendTypes s
RecvDualTypes (Send ty s) = cong (ty ::) (RecvDualTypes s)
RecvDualTypes (Recv ty s) = RecvDualTypes s
RecvDualTypes End = Refl

||| The types sent by my dual are exactly the ones I receive
||| This definition is purely internal and will not show up in
||| the library's interface.
SendDualTypes : (s : Session) -> SendTypes (Dual s) === RecvTypes s
SendDualTypes (Send ty s) = SendDualTypes s
SendDualTypes (Recv ty s) = cong (ty ::) (SendDualTypes s)
SendDualTypes End = Refl

namespace Seen

||| The inductive family (Seen m n f) states that the function `f`
||| was obtained by composing an interleaving of `m` receiving
||| steps and `n` sending ones.
public export
data Seen : Nat -> Nat -> (Session -> Session) -> Type where
None : Seen 0 0 Prelude.id
Recv : (ty : Type) -> Seen m n f -> Seen (S m) n (f . Recv ty)
Send : (ty : Type) -> Seen m n f -> Seen m (S n) (f . Send ty)

||| If we know that `ty` is at index `n` in the list of received types
||| and that `f` is a function defined using an interleaving of steps
||| comprising `m` receiving stepsx then `ty` is at index `m + n` in `f s`.
atRecvIndex : Seen m _ f ->
(s : Session) ->
AtIndex ty (RecvTypes s) n ->
AtIndex ty (RecvTypes (f s)) (m + n)
atRecvIndex None accS accAt = accAt
atRecvIndex (Recv ty s) accS accAt
= rewrite plusSuccRightSucc (pred m) n in
atRecvIndex s (Recv ty accS) (S accAt)
atRecvIndex (Send ty s) accS accAt
= atRecvIndex s (Send ty accS) accAt

||| If we know that `ty` is at index `n` in the list of sent types
||| and that `f` is a function defined using an interleaving of steps
||| comprising `m` sending steps then `ty` is at index `m + n` in `f s`.
atSendIndex : Seen _ m f ->
(s : Session) ->
AtIndex ty (SendTypes s) n ->
AtIndex ty (SendTypes (f s)) (m + n)
atSendIndex None accS accAt = accAt
atSendIndex (Recv ty s) accS accAt
= atSendIndex s (Recv ty accS) accAt
atSendIndex (Send ty s) accS accAt
= rewrite plusSuccRightSucc (pred m) n in
atSendIndex s (Send ty accS) (S accAt)


||| A (bidirectional) channel is parametrised by session it must respect.
|||
||| It is implemented in terms of two low-level channels: one for sending
||| and one for receiving. This ensures that we never are in a situation
||| where a thread with session (Send Nat (Recv String ...)) sends a natural
||| number and subsequently performs a receive before the other party
||| to the communication had time to grab the Nat thus receiving it
||| instead of a String.
|||
||| The low-level channels can only carry values of a single type. And so
||| they are given respective union types corresponding to the types that
||| can be sent on the one hand and the ones that can be received on the
||| other.
||| These union types are tagged unions where if `ty` is at index `k` in
||| the list of types `tys` then `(k, v)` is a value of `Union m tys`
||| provided that `v` has type `m ty`.
|||
||| `sendStep`, `recvStep`, `seePrefix`, and `seen` encode the fact that
||| we have already performed some of the protocol and so the low-level
||| channels' respective types necessarily mention types that we won't
||| see anymore.
export
record Channel (s : Session) where
constructor MkChannel
{sendStep : Nat}
{recvStep : Nat}
{0 seenPrefix : Session -> Session}
0 seen : Seen recvStep sendStep seenPrefix

sendChan : Threads.Channel (Union (SendTypes (seenPrefix s)))
recvChan : Threads.Channel (Union (RecvTypes (seenPrefix s)))

||| Linear version of `die`
export
die1 : LinearIO io => String -> L1 io a
die1 err = do
x <- die err
pure1 x

||| Consume a linear channel with a `Recv ty` step at the head of the
||| session type in order to obtain a value of type `ty` together with
||| a linear channel for the rest of the session.
export
recv : LinearIO io =>
Channel (Recv ty s) -@
L1 io (Res ty (const (Channel s)))
recv (MkChannel {recvStep} seen sendCh recvCh) = do
x@(Element k prf val) <- channelGet recvCh
-- Here we check that we got the right message by projecting out of
-- the union type using the current `recvStep`. Both ends should be
-- in sync because of the `RecvDualTypes` and `SendDualTypes` lemmas.
let Just val = prj (recvStep + 0) (atRecvIndex seen (Recv ty s) Z) x
| Nothing => die1 "Error: invalid recv expected \{show recvStep} but got \{show k}"
pure1 (val # MkChannel (Recv ty seen) sendCh recvCh)


||| Consume a linear channel with a `Send ty` step at the head of the
||| session type in order to send a value of type `ty` and obtain a
||| linear channel for the rest of the session.
export
send : LinearIO io =>
(1 _ : Channel (Send ty s)) ->
ty ->
L1 io (Channel s)
send (MkChannel {sendStep} seen sendCh recvCh) x = do
let val = inj (sendStep + 0) (atSendIndex seen (Send ty s) Z) x
channelPut sendCh val
pure1 (MkChannel (Send ty seen) sendCh recvCh)

||| Discard the channel provided that the session has reached its `End`.
export
end : LinearIO io => Channel End -@ L io ()
end (MkChannel _ _ _) = do
pure ()

||| Given a session, create a bidirectional communiaction channel and
||| return its two endpoints
export
makeChannel :
LinearIO io =>
(0 s : Session) ->
L1 io (LPair (Channel s) (Channel (Dual s)))
makeChannel s = do
sendChan <- Threads.makeChannel
recvChan <- Threads.makeChannel
let 1 posCh : Channel s
:= MkChannel None sendChan recvChan
let 1 negCh : Channel (Dual s)
:= MkChannel None
(rewrite SendDualTypes s in recvChan)
(rewrite RecvDualTypes s in sendChan)
pure1 (posCh # negCh)

||| Given a session and two functions communicating according to that
||| sesion, we can run the two programs concurrently and collect their
||| final results.
export
fork : (0 s : Session) ->
(Channel s -@ L IO a) -@
(Channel (Dual s) -@ L IO b) -@
L IO (a, b)
fork s kA kB = do
let 1 io = makeChannel s
(posCh # negCh) <- io
par (kA posCh) (kB negCh)
5 changes: 4 additions & 1 deletion libs/linear/linear.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ modules = Control.Linear.LIO,
Data.Linear.LMaybe,
Data.Linear.LNat,
Data.Linear.LVect,
Data.Linear.Notation
Data.Linear.Notation,

System.Concurrency.Linear,
System.Concurrency.Session

0 comments on commit e115aca

Please sign in to comment.