1
- {-# LANGUAGE TypeOperators#-}
2
- {-# LANGUAGE FlexibleInstances#-}
3
- {-# LANGUAGE MultiParamTypeClasses#-}
4
- {-# LANGUAGE UndecidableInstances#-}
5
- {-# LANGUAGE DeriveGeneric#-}
1
+ {-# LANGUAGE DataKinds #-}
6
2
{-# LANGUAGE DeriveAnyClass#-}
3
+ {-# LANGUAGE DeriveGeneric#-}
7
4
{-# LANGUAGE DuplicateRecordFields#-}
5
+ {-# LANGUAGE FlexibleContexts#-}
6
+ {-# LANGUAGE FlexibleInstances#-}
7
+ {-# LANGUAGE FunctionalDependencies#-}
8
+ {-# LANGUAGE GADTs#-}
9
+ {-# LANGUAGE MultiParamTypeClasses#-}
10
+ {-# LANGUAGE OverloadedRecordDot#-}
11
+ {-# LANGUAGE PartialTypeSignatures #-}
12
+ {-# LANGUAGE PolyKinds #-}
8
13
{-# LANGUAGE RecordWildCards#-}
9
14
{-# LANGUAGE ScopedTypeVariables#-}
10
15
{-# LANGUAGE TypeApplications#-}
11
16
{-# LANGUAGE TypeFamilies#-}
12
- {-# LANGUAGE GADTs#-}
13
- {-# LANGUAGE OverloadedRecordDot#-}
14
- {-# LANGUAGE FlexibleContexts#-}
15
- {-# LANGUAGE FunctionalDependencies#-}
16
17
{-# LANGUAGE TypeFamilyDependencies#-}
17
- {-# LANGUAGE PartialTypeSignatures #-}
18
+ {-# LANGUAGE TypeOperators#-}
19
+ {-# LANGUAGE UndecidableInstances#-}
18
20
19
21
20
22
module Torch.Compose where
@@ -23,13 +25,31 @@ import Torch
23
25
import Torch.NN
24
26
import Torch.Functional
25
27
import GHC.Generics hiding ((:+:) )
28
+ -- import Data.Void
29
+ import Data.HList
30
+ import Data.HList (hAppend )
31
+ import Data.Kind
32
+ import Data.Coerce
33
+ import Control.Exception
34
+ import System.IO.Unsafe
35
+
36
+ instance (Randomizable spec0 f0 , Randomizable (HList spec1 ) (HList f1 )) => Randomizable (HList (spec0 ': spec1 )) (HList (f0 ': f1 )) where
37
+ sample (HCons s0 s1) = do
38
+ f0 <- sample s0
39
+ f1 <- sample s1
40
+ return (f0 .*. f1)
26
41
27
- data (:>>: ) a b = (:>>:)
28
- { head :: a
29
- , tail :: b
30
- } deriving (Show , Eq , Generic )
42
+ instance Randomizable (HList '[] ) (HList '[] ) where
43
+ sample HNil = do
44
+ return HNil
45
+
46
+ instance (HasForward f a b , HasForward (HList g ) b c ) => HasForward (HList (f ': g )) a c where
47
+ forward (HCons f g) a = forward g (forward f a)
48
+ forwardStoch (HCons f g) a = forwardStoch f a >>= forwardStoch g
31
49
32
- infixr 5 :>>:
50
+ instance HasForward (HList '[] ) a a where
51
+ forward _ = id
52
+ forwardStoch _ = pure
33
53
34
54
data (://: ) a b = Fanout
35
55
{ head :: a
@@ -46,16 +66,6 @@ data (:++:) a b = Concat
46
66
, tail :: b
47
67
} deriving (Show , Eq , Generic )
48
68
49
- instance (Randomizable spec0 f0 , Randomizable spec1 f1 ) => Randomizable (spec0 :>>: spec1 ) (f0 :>>: f1 ) where
50
- sample ((:>>:) s0 s1) = do
51
- f0 <- sample s0
52
- f1 <- sample s1
53
- return ((:>>:) f0 f1)
54
-
55
- instance (HasForward f a b , HasForward g b c ) => HasForward (f :>>: g ) a c where
56
- forward ((:>>:) f g) a = forward g (forward f a)
57
- forwardStoch ((:>>:) f g) a = forwardStoch f a >>= forwardStoch g
58
-
59
69
instance (Randomizable spec0 f0 , Randomizable spec1 f1 ) => Randomizable (spec0 ://: spec1 ) (f0 ://: f1 ) where
60
70
sample (Fanout s0 s1) = do
61
71
f0 <- sample s0
@@ -117,58 +127,78 @@ instance (HasForward a b b) => HasForward (Replicate a) b b where
117
127
forwardStoch (Replicate [] ) input = pure input
118
128
forwardStoch (Replicate (a: ax)) input = forwardStoch (Replicate ax) =<< forwardStoch a input
119
129
120
- type family LastLayer x where
121
- LastLayer (a :>>: b ) = LastLayer b
122
- LastLayer x = x
123
-
124
- class HasLast x r | x -> r where
125
- getLast :: x -> r
126
-
127
- instance HasLast b r => HasLast (a :>>: b ) r where
128
- getLast ((:>>:) _ b) = getLast b
129
-
130
- instance HasLast a a where
131
- getLast = id
132
-
133
- type family FirstLayer x where
134
- FirstLayer (a :>>: b ) = a
135
- FirstLayer x = x
136
-
137
- class HasFirst x r | x -> r where
138
- getFirst :: x -> r
139
-
140
- instance HasFirst a r => HasFirst (a :>>: b ) r where
141
- getFirst ((:>>:) a _) = getFirst a
142
-
143
- instance HasFirst a a where
144
- getFirst = id
145
-
146
130
class HasForwardAssoc f a where
147
- type ForwardResult f a
131
+ type ForwardResult f a :: Type
148
132
forwardAssoc :: f -> a -> ForwardResult f a
149
133
150
- class HasOutputs f a where
151
- type Outputs f a
152
- toOutputs :: f -> a -> Outputs f a
153
-
154
- instance (HasForwardAssoc f0 a , HasOutputs f0 a , HasOutputs f1 (ForwardResult f0 a )) => HasOutputs (f0 :>>: f1 ) a where
155
- type Outputs (f0 :>>: f1 ) a = Outputs f0 a :>>: Outputs f1 (ForwardResult f0 a )
156
- toOutputs ((:>>:) f0 f1) a = (:>>:) (toOutputs f0 a) (toOutputs f1 (forwardAssoc f0 a))
157
-
158
- class HasInputs f a where
159
- type Inputs f a
160
- toInputs :: f -> a -> Inputs f a
161
-
162
- instance (HasForwardAssoc f0 a , HasInputs f0 a , HasInputs f1 (ForwardResult f0 a )) => HasInputs (f0 :>>: f1 ) a where
163
- type Inputs (f0 :>>: f1 ) a = Inputs f0 a :>>: Inputs f1 (ForwardResult f0 a )
164
- toInputs ((:>>:) f0 f1) a = (:>>:) (toInputs f0 a) (toInputs f1 (forwardAssoc f0 a))
165
-
166
-
167
- class HasOutputShapes f a where
168
- type OutputShapes f a
169
- toOutputShapes :: f -> a -> OutputShapes f a
170
-
171
- instance (HasForwardAssoc f0 a , HasOutputShapes f0 a , HasOutputShapes f1 (ForwardResult f0 a )) => HasOutputShapes (f0 :>>: f1 ) a where
172
- type OutputShapes (f0 :>>: f1 ) a = OutputShapes f0 a :>>: OutputShapes f1 (ForwardResult f0 a )
173
- toOutputShapes ((:>>:) f0 f1) a = (:>>:) (toOutputShapes f0 a) (toOutputShapes f1 (forwardAssoc f0 a))
174
-
134
+ toHList :: x -> HList '[x ]
135
+ toHList x = HCons x HNil
136
+
137
+ instance (HasForwardAssoc f a ) => HasForwardAssoc f (HList '[a ]) where
138
+ type ForwardResult f (HList '[a ]) = HList '[ForwardResult f a ]
139
+ forwardAssoc f (HCons a HNil ) = toHList $ forwardAssoc f a
140
+
141
+ dropLastLayer :: (Coercible a (HList xs1 ), HRevApp xs2 '[x ] xs1 , HRevApp xs2 '[] sx , HRevApp xs1 '[] (x : xs2 ), HRevApp sx '[] xs2 ) => a -> HList sx
142
+ dropLastLayer m = hReverse (hDrop (Proxy :: Proxy (HSucc HZero )) (hReverse (coerce m)))
143
+
144
+ addLastLayer :: HAppend l1 (HList '[e ]) => l1 -> e -> HAppendR l1 (HList '[e ])
145
+ addLastLayer a b = a `hAppend` (b .*. HNil )
146
+
147
+ getLastLayer :: (Coercible a (HList l1 ), HRevApp l1 '[] (e : l )) => a -> e
148
+ getLastLayer a = hLast (coerce a)
149
+
150
+ hScanl :: forall f z ls xs1 sx xs2 . (HScanr f z ls xs1 , HRevApp xs1 '[] sx , HRevApp sx '[] xs1 , HRevApp xs2 '[] ls , HRevApp ls '[] xs2 ) => f -> z -> HList xs2 -> HList sx
151
+ hScanl a b c = hReverse $ hScanr a b (hReverse c)
152
+
153
+ safeEval :: forall a . a -> Maybe a
154
+ safeEval x = unsafePerformIO $ do
155
+ result <- try (evaluate @ a x) :: IO (Either SomeException a )
156
+ case result of
157
+ Left _ -> return Nothing
158
+ Right v -> return (Just v)
159
+
160
+ type family ForwardMap (xs :: [* ]) (a :: * ) :: [* ] where
161
+ ForwardMap '[] _ = '[]
162
+ ForwardMap (x ': xs ) a = ForwardResult x a ': ForwardMap xs (ForwardResult x a )
163
+
164
+ class Outputs xs input where
165
+ toOutputs' :: HList xs -> input -> HList (ForwardMap xs input )
166
+
167
+ instance HasForwardAssoc x a => HasForwardAssoc x (Maybe a ) where
168
+ type ForwardResult x (Maybe a ) = Maybe (ForwardResult x a )
169
+ forwardAssoc x (Just a) = Just $ forwardAssoc x a
170
+ forwardAssoc x Nothing = Nothing
171
+
172
+
173
+ instance (HasForwardAssoc x a , Outputs xs (ForwardResult x a )) => Outputs (x ': xs ) a where
174
+ toOutputs' (HCons x xs) a =
175
+ let out = forwardAssoc x a
176
+ in HCons out $ toOutputs' xs out
177
+
178
+ instance Outputs '[] a where
179
+ toOutputs' _ _ = HNil
180
+
181
+ toOutputs ::
182
+ (Coercible a (HList xs ),
183
+ Outputs xs input
184
+ ) =>
185
+ a -> input -> HList (ForwardMap xs input )
186
+ toOutputs f = toOutputs' (coerce f)
187
+
188
+ toOutputShapes ::
189
+ (Coercible a (HList xs ),
190
+ HMapAux HList (Tensor -> [Int ]) (ForwardMap xs input ) b ,
191
+ SameLength' b (ForwardMap xs input ),
192
+ SameLength' (ForwardMap xs input ) b , Outputs xs input
193
+ ) =>
194
+ a -> input -> HList b
195
+ toOutputShapes f a = hMap shape (toOutputs f a)
196
+
197
+ toMaybeOutputShapes ::
198
+ (Coercible a (HList xs ),
199
+ HMapAux HList (Tensor -> Maybe [Int ]) (ForwardMap xs input ) b ,
200
+ SameLength' b (ForwardMap xs input ),
201
+ SameLength' (ForwardMap xs input ) b , Outputs xs input
202
+ ) =>
203
+ a -> input -> HList b
204
+ toMaybeOutputShapes f a = hMap (safeEval . shape) (toOutputs f a)
0 commit comments