{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PolyKinds #-}
module Main where
import Control.Category
import Control.Monad.Indexed
import Data.Composition
import Data.Distributive
import Data.Random
import Linear
import Linear.V
import Numeric.AD
import Prelude hiding (id, (.))
import IndexedTardis
import WrapIndex
-- | A single sigmoid layer from an @a@-dimensional input to a @b@-dimensional
-- output: a @b@-by-@a@ weight matrix together with a @b@-vector of biases.
data Layer f a b = Layer {weights :: V b (V a f), biases :: V b f}
-- | A network is a series of compatible layers
data Network f a b where
  Id :: Network f a a
  Lr :: (Dim a, Dim b) => Layer f a b -> Network f b c -> Network f a c
-- | Networks form a category in the same way lists form a monoid
instance Category (Network f) where
  id = Id
  m . Id = m
  m . Lr l n = Lr l (m . n)
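
-- For example, composing @Lr l2 Id . Lr l1 Id@ gives the same network as
-- @Lr l1 (Lr l2 Id)@: the right-hand network's layers come first, just as the
-- right-hand function is applied first in ordinary function composition.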
-- | Left fold over a network. Apply a function layer by layer forwards through
-- the network.
--
-- If you don't want the type being folded over to be parametrised, or want it
-- parametrised in a different way, you can use wrapper types like Flip and
-- Const to introduce or move a parameter; see the layer-counting example
-- after the definition below.
foldlNetwork :: (forall a b. (Dim a, Dim b) => v a -> Layer f a b -> v b) ->
  v a -> Network f a b -> v b
foldlNetwork f v Id = v
foldlNetwork f v (Lr l n) = foldlNetwork f (f v l) n
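
-- As an illustration of folding with a wrapper type (a sketch: @Const@ here is
-- 'Data.Functor.Const', and @layerCount@ is not part of this module):
--
-- > layerCount :: Network f a b -> Int
-- > layerCount = getConst . foldlNetwork (\(Const n) _ -> Const (n + 1)) (Const 0)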
-- | Right fold over a network. Apply a function layer by layer backwards
-- through the network.
foldrNetwork :: (forall a b. (Dim a, Dim b) => Layer f a b -> v b -> v a) ->
  v b -> Network f a b -> v a
foldrNetwork f v Id = v
foldrNetwork f v (Lr l n) = f l (foldrNetwork f v n)
-- | Traverse a network using an 'IxApplicative'.
traverseNetwork :: IxApplicative ia =>
  (forall a b. (Dim a, Dim b) => Layer f a b -> ia a b (Layer g a b)) ->
  Network f a b -> ia a b (Network g a b)
traverseNetwork _ Id = ireturn Id
traverseNetwork f (Lr l n) = Lr `imap` f l `iap` traverseNetwork f n
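
-- | Swap the argument order of a two-parameter type constructor: @Flip V f a@
-- wraps a @V a f@, so a vector can be indexed by its dimension in the folds
-- and traversals above.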
newtype Flip f b a = Flip {unflip :: f a b}
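-- | Feed an input vector through a single layer: multiply by the weight
-- matrix, add the biases, and apply the sigmoid pointwise.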
feedForward1 :: (Floating f, Dim a, Dim b) => (f -> f) ->
  V a f -> Layer f a b -> V b f
feedForward1 sigmoid i Layer{..} = sigmoid <$> weights !* i + biases
-- | Given a sigmoid function, for example the logistic function
-- @\x -> 1 / (1 + exp (- x))@, or @tanh@ or @atan@, feed an input vector
-- through the network.
feedForward :: Floating f => (f -> f) -> V a f -> Network f a b -> V b f
feedForward sigmoid =
  unflip .: foldlNetwork (Flip .: feedForward1 sigmoid . unflip) . Flip
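
-- A usage sketch (assuming a network @net :: Network Double 2 1@ and an input
-- @x :: V 2 Double@ defined elsewhere):
--
-- > feedForward (\z -> 1 / (1 + exp (-z))) x net :: V 1 Double

-- | One layer's step of backpropagation, expressed as an indexed Tardis
-- computation: the backwards-travelling state carries the error delta coming
-- from the next layer, the forwards-travelling state carries the weighted
-- input coming from the previous layer, and the returned layer has its
-- weights and biases adjusted by the learning rate @lr@.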
backPropogate1 :: (Floating f, Dim a, Dim b) =>
  (forall f. Floating f => f -> f) -> f ->
  Layer f a b -> ITardis (Flip V f) (Flip V f) a b (Layer f a b)
backPropogate1 sigmoid lr Layer{..} = ITardis $ \(Flip db, Flip za) ->
  let aa = sigmoid za
      zb = weights !* aa + biases
      da = transpose weights !* db * diff sigmoid za
  in (Layer {weights = weights - lr *!! db `outer` aa,
             biases = biases - lr *^ db},
      (Flip da, Flip zb))
-- | Given a pair of an input and its expected output, a sigmoid function, a
-- cost function (for example @qd@), and a learning rate, run the
-- backpropagation algorithm for a network on that single input.
backPropogate' :: (Dim b, Floating f) => (V a f, V b f) ->
  (forall f. Floating f => f -> f) ->
  (forall f. Floating f => V b f -> V b f -> f) -> f ->
  Network f a b -> Network f a b
backPropogate' (x, y) sigmoid cost lr network =
  let (network', (_, Flip zl)) = runITardis
        (traverseNetwork (backPropogate1 sigmoid lr) network)
        (Flip dl, Flip x)
      al = sigmoid zl
      dl = grad (cost $ fmap auto y) al * diff sigmoid zl
  in network'
-- | 'backPropogate\'' with a sensible default sigmoid and cost.
backPropogate :: (Dim b, Floating f) => (V a f, V b f) -> f ->
  Network f a b -> Network f a b
backPropogate (x, y) = backPropogate' (x, y) logistic qd
  where
    logistic x = 1 / (1 + exp (- x))
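
-- A single training step might look like this (a sketch: @net@, @x@, and @y@
-- are assumed to exist, and @0.1@ is an arbitrary learning rate):
--
-- > backPropogate (x, y) 0.1 net

-- | Replicate every weight and bias of a network @w@ times, so the network
-- can process a whole mini-batch of @w@ inputs at once.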
batch :: Dim w => Network f a b -> Network (V w f) a b
batch Id = Id
batch (Lr Layer{..} n) = Lr l' (batch n)
  where
    l' = Layer {weights = fmap (fmap pure) weights, biases = fmap pure biases}
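
-- | Collapse a batched network back into a single network by averaging each
-- parameter over the batch dimension.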
unbatch :: (Dim w, Fractional f) => Network (V w f) a b -> Network f a b
unbatch Id = Id
unbatch (Lr Layer{..} n) = Lr l' (unbatch n)
  where
    l' = Layer {weights = fmap (fmap average) weights,
                biases = fmap average biases}
    average v = sum v / fromIntegral (dim v)
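
-- | Run backpropagation over a whole mini-batch at once: batch the network,
-- train it on the batched inputs and outputs with the given sigmoid and cost,
-- then average the batch back down to a single network.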
runMiniBatch' :: (Dim w, Dim a, Dim b, Floating f) => V w (V a f, V b f) ->
  (forall f. Floating f => f -> f) ->
  (forall f. Floating f => V b f -> V b f -> f) -> f ->
  Network f a b -> Network f a b
runMiniBatch' xys sigmoid cost lr =
  unbatch .
  backPropogate' (collect fst xys, collect snd xys) sigmoid cost (pure lr) .
  batch
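
-- | 'runMiniBatch\'' with the same default sigmoid and cost as 'backPropogate'.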
runMiniBatch :: (Dim w, Dim a, Dim b, Floating f) => V w (V a f, V b f) ->
  f -> Network f a b -> Network f a b
runMiniBatch xys lr =
  unbatch .
  backPropogate (collect fst xys, collect snd xys) (pure lr) .
  batch
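
-- | A layer whose parameters are all @()@ placeholders: it records only the
-- shape, to be filled in by 'randomizeLayer\'' or 'randomize'.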
emptyLayer :: (Dim a, Dim b) => Layer () a b
emptyLayer = Layer {weights = pure (pure ()), biases = pure ()}
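
-- | Replace every parameter of a layer with an independent draw from the
-- given distribution.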
randomizeLayer' :: (Dim a, Dim b, Distribution d t) => d t -> Layer x a b -> RVar (Layer t a b)
randomizeLayer' dist Layer{..} =
  Layer <$>
  traverse (traverse (\_ -> rvar dist)) weights <*>
  traverse (\_ -> rvar dist) biases
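
-- | Replace every parameter of a network with an independent draw from the
-- given distribution.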
randomize' :: Distribution d t => d t -> Network x a b -> RVar (Network t a b)
randomize' dist = iunwrap . traverseNetwork (IWrap . randomizeLayer' dist)
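
-- | 'randomize\'' with a standard normal distribution.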
randomize :: Distribution Normal t => Network x a b -> RVar (Network t a b)
randomize = randomize' StdNormal
main :: IO ()
main = do
  -- Build the shape of a 784-30-30-10 network and fill it with standard
  -- normal weights; the randomised network is not used further in this stub.
  let net0 = Lr (emptyLayer :: Layer () 784 30) $
             Lr (emptyLayer :: Layer () 30 30) $
             Lr (emptyLayer :: Layer () 30 10) $
             Id
  net <- runRVar (randomize net0) StdRandom :: IO (Network Double 784 10)
  return ()