Skip to content

Commit 4ac9242

Browse files
Update
1 parent c0ed0b8 commit 4ac9242

File tree

6 files changed

+298
-182
lines changed

6 files changed

+298
-182
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,7 @@ mlp model input = forward model input
4242

4343
For one input, take the outputs of all layers, then compare the shapes and values of all the layers.
4444

45+
## Overlay layer
46+
47+
## Concatenate layer
48+

hasktorch-compose.cabal

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ common base
2020
build-depends:
2121
base >= 4.12 && < 5
2222
, hasktorch >= 0.2 && < 0.3
23+
, HList
2324

2425
common binary-base
2526
import: base
@@ -36,8 +37,8 @@ library
3637
Torch.Typed.Compose.NN
3738
Torch.Typed.Compose.Models
3839
hs-source-dirs: src
39-
default-extensions: Strict
40-
, StrictData
40+
-- default-extensions: Strict
41+
-- , StrictData
4142

4243
test-suite spec
4344
type: exitcode-stdio-1.0
@@ -47,6 +48,7 @@ test-suite spec
4748
, hasktorch
4849
, hasktorch-compose
4950
, hspec
51+
, HList
5052

5153
executable example
5254
import: binary-base

src/Torch/Compose.hs

Lines changed: 107 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
1-
{-# LANGUAGE TypeOperators#-}
2-
{-# LANGUAGE FlexibleInstances#-}
3-
{-# LANGUAGE MultiParamTypeClasses#-}
4-
{-# LANGUAGE UndecidableInstances#-}
5-
{-# LANGUAGE DeriveGeneric#-}
1+
{-# LANGUAGE DataKinds #-}
62
{-# LANGUAGE DeriveAnyClass#-}
3+
{-# LANGUAGE DeriveGeneric#-}
74
{-# LANGUAGE DuplicateRecordFields#-}
5+
{-# LANGUAGE FlexibleContexts#-}
6+
{-# LANGUAGE FlexibleInstances#-}
7+
{-# LANGUAGE FunctionalDependencies#-}
8+
{-# LANGUAGE GADTs#-}
9+
{-# LANGUAGE MultiParamTypeClasses#-}
10+
{-# LANGUAGE OverloadedRecordDot#-}
11+
{-# LANGUAGE PartialTypeSignatures #-}
12+
{-# LANGUAGE PolyKinds #-}
813
{-# LANGUAGE RecordWildCards#-}
914
{-# LANGUAGE ScopedTypeVariables#-}
1015
{-# LANGUAGE TypeApplications#-}
1116
{-# LANGUAGE TypeFamilies#-}
12-
{-# LANGUAGE GADTs#-}
13-
{-# LANGUAGE OverloadedRecordDot#-}
14-
{-# LANGUAGE FlexibleContexts#-}
15-
{-# LANGUAGE FunctionalDependencies#-}
1617
{-# LANGUAGE TypeFamilyDependencies#-}
17-
{-# LANGUAGE PartialTypeSignatures #-}
18+
{-# LANGUAGE TypeOperators#-}
19+
{-# LANGUAGE UndecidableInstances#-}
1820

1921

2022
module Torch.Compose where
@@ -23,13 +25,31 @@ import Torch
2325
import Torch.NN
2426
import Torch.Functional
2527
import GHC.Generics hiding ((:+:))
28+
-- import Data.Void
29+
import Data.HList
30+
import Data.HList (hAppend)
31+
import Data.Kind
32+
import Data.Coerce
33+
import Control.Exception
34+
import System.IO.Unsafe
35+
36+
instance (Randomizable spec0 f0, Randomizable (HList spec1) (HList f1)) => Randomizable (HList (spec0 ': spec1)) (HList (f0 ': f1)) where
37+
sample (HCons s0 s1) = do
38+
f0 <- sample s0
39+
f1 <- sample s1
40+
return (f0 .*. f1)
2641

27-
data (:>>:) a b = (:>>:)
28-
{ head :: a
29-
, tail :: b
30-
} deriving (Show, Eq, Generic)
42+
instance Randomizable (HList '[]) (HList '[]) where
43+
sample HNil = do
44+
return HNil
45+
46+
instance (HasForward f a b, HasForward (HList g) b c) => HasForward (HList (f ': g)) a c where
47+
forward (HCons f g) a = forward g (forward f a)
48+
forwardStoch (HCons f g) a = forwardStoch f a >>= forwardStoch g
3149

32-
infixr 5 :>>:
50+
instance HasForward (HList '[]) a a where
51+
forward _ = id
52+
forwardStoch _ = pure
3353

3454
data (://:) a b = Fanout
3555
{ head :: a
@@ -46,16 +66,6 @@ data (:++:) a b = Concat
4666
, tail :: b
4767
} deriving (Show, Eq, Generic)
4868

49-
instance (Randomizable spec0 f0, Randomizable spec1 f1) => Randomizable (spec0 :>>: spec1) (f0 :>>: f1) where
50-
sample ((:>>:) s0 s1) = do
51-
f0 <- sample s0
52-
f1 <- sample s1
53-
return ((:>>:) f0 f1)
54-
55-
instance (HasForward f a b, HasForward g b c) => HasForward (f :>>: g) a c where
56-
forward ((:>>:) f g) a = forward g (forward f a)
57-
forwardStoch ((:>>:) f g) a = forwardStoch f a >>= forwardStoch g
58-
5969
instance (Randomizable spec0 f0, Randomizable spec1 f1) => Randomizable (spec0 ://: spec1) (f0 ://: f1) where
6070
sample (Fanout s0 s1) = do
6171
f0 <- sample s0
@@ -117,58 +127,78 @@ instance (HasForward a b b) => HasForward (Replicate a) b b where
117127
forwardStoch (Replicate []) input = pure input
118128
forwardStoch (Replicate (a:ax)) input = forwardStoch (Replicate ax) =<< forwardStoch a input
119129

120-
type family LastLayer x where
121-
LastLayer (a :>>: b) = LastLayer b
122-
LastLayer x = x
123-
124-
class HasLast x r | x -> r where
125-
getLast :: x -> r
126-
127-
instance HasLast b r => HasLast (a :>>: b) r where
128-
getLast ((:>>:) _ b) = getLast b
129-
130-
instance HasLast a a where
131-
getLast = id
132-
133-
type family FirstLayer x where
134-
FirstLayer (a :>>: b) = a
135-
FirstLayer x = x
136-
137-
class HasFirst x r | x -> r where
138-
getFirst :: x -> r
139-
140-
instance HasFirst a r => HasFirst (a :>>: b) r where
141-
getFirst ((:>>:) a _) = getFirst a
142-
143-
instance HasFirst a a where
144-
getFirst = id
145-
146130
class HasForwardAssoc f a where
147-
type ForwardResult f a
131+
type ForwardResult f a :: Type
148132
forwardAssoc :: f -> a -> ForwardResult f a
149133

150-
class HasOutputs f a where
151-
type Outputs f a
152-
toOutputs :: f -> a -> Outputs f a
153-
154-
instance (HasForwardAssoc f0 a, HasOutputs f0 a, HasOutputs f1 (ForwardResult f0 a)) => HasOutputs (f0 :>>: f1) a where
155-
type Outputs (f0 :>>: f1) a = Outputs f0 a :>>: Outputs f1 (ForwardResult f0 a)
156-
toOutputs ((:>>:) f0 f1) a = (:>>:) (toOutputs f0 a) (toOutputs f1 (forwardAssoc f0 a))
157-
158-
class HasInputs f a where
159-
type Inputs f a
160-
toInputs :: f -> a -> Inputs f a
161-
162-
instance (HasForwardAssoc f0 a, HasInputs f0 a, HasInputs f1 (ForwardResult f0 a)) => HasInputs (f0 :>>: f1) a where
163-
type Inputs (f0 :>>: f1) a = Inputs f0 a :>>: Inputs f1 (ForwardResult f0 a)
164-
toInputs ((:>>:) f0 f1) a = (:>>:) (toInputs f0 a) (toInputs f1 (forwardAssoc f0 a))
165-
166-
167-
class HasOutputShapes f a where
168-
type OutputShapes f a
169-
toOutputShapes :: f -> a -> OutputShapes f a
170-
171-
instance (HasForwardAssoc f0 a, HasOutputShapes f0 a, HasOutputShapes f1 (ForwardResult f0 a)) => HasOutputShapes (f0 :>>: f1) a where
172-
type OutputShapes (f0 :>>: f1) a = OutputShapes f0 a :>>: OutputShapes f1 (ForwardResult f0 a)
173-
toOutputShapes ((:>>:) f0 f1) a = (:>>:) (toOutputShapes f0 a) (toOutputShapes f1 (forwardAssoc f0 a))
174-
134+
toHList :: x -> HList '[x]
135+
toHList x = HCons x HNil
136+
137+
instance (HasForwardAssoc f a) => HasForwardAssoc f (HList '[a]) where
138+
type ForwardResult f (HList '[a]) = HList '[ForwardResult f a]
139+
forwardAssoc f (HCons a HNil) = toHList $ forwardAssoc f a
140+
141+
dropLastLayer :: (Coercible a (HList xs1), HRevApp xs2 '[x] xs1, HRevApp xs2 '[] sx, HRevApp xs1 '[] (x : xs2), HRevApp sx '[] xs2) => a -> HList sx
142+
dropLastLayer m = hReverse (hDrop (Proxy :: Proxy (HSucc HZero)) (hReverse (coerce m)))
143+
144+
addLastLayer :: HAppend l1 (HList '[e]) => l1 -> e -> HAppendR l1 (HList '[e])
145+
addLastLayer a b = a `hAppend` (b .*. HNil)
146+
147+
getLastLayer :: (Coercible a (HList l1), HRevApp l1 '[] (e : l)) => a -> e
148+
getLastLayer a = hLast (coerce a)
149+
150+
hScanl :: forall f z ls xs1 sx xs2. (HScanr f z ls xs1, HRevApp xs1 '[] sx, HRevApp sx '[] xs1, HRevApp xs2 '[] ls, HRevApp ls '[] xs2) => f -> z -> HList xs2 -> HList sx
151+
hScanl a b c = hReverse $ hScanr a b (hReverse c)
152+
153+
safeEval :: forall a. a -> Maybe a
154+
safeEval x = unsafePerformIO $ do
155+
result <- try (evaluate @a x) :: IO (Either SomeException a)
156+
case result of
157+
Left _ -> return Nothing
158+
Right v -> return (Just v)
159+
160+
type family ForwardMap (xs :: [*]) (a :: *) :: [*] where
161+
ForwardMap '[] _ = '[]
162+
ForwardMap (x ': xs) a = ForwardResult x a ': ForwardMap xs (ForwardResult x a)
163+
164+
class Outputs xs input where
165+
toOutputs' :: HList xs -> input -> HList (ForwardMap xs input)
166+
167+
instance HasForwardAssoc x a => HasForwardAssoc x (Maybe a) where
168+
type ForwardResult x (Maybe a) = Maybe (ForwardResult x a)
169+
forwardAssoc x (Just a) = Just $ forwardAssoc x a
170+
forwardAssoc x Nothing = Nothing
171+
172+
173+
instance (HasForwardAssoc x a, Outputs xs (ForwardResult x a)) => Outputs (x ': xs) a where
174+
toOutputs' (HCons x xs) a =
175+
let out = forwardAssoc x a
176+
in HCons out $ toOutputs' xs out
177+
178+
instance Outputs '[] a where
179+
toOutputs' _ _ = HNil
180+
181+
toOutputs ::
182+
(Coercible a (HList xs),
183+
Outputs xs input
184+
) =>
185+
a -> input -> HList (ForwardMap xs input)
186+
toOutputs f = toOutputs' (coerce f)
187+
188+
toOutputShapes ::
189+
(Coercible a (HList xs),
190+
HMapAux HList (Tensor -> [Int]) (ForwardMap xs input) b,
191+
SameLength' b (ForwardMap xs input),
192+
SameLength' (ForwardMap xs input) b, Outputs xs input
193+
) =>
194+
a -> input -> HList b
195+
toOutputShapes f a = hMap shape (toOutputs f a)
196+
197+
toMaybeOutputShapes ::
198+
(Coercible a (HList xs),
199+
HMapAux HList (Tensor -> Maybe [Int]) (ForwardMap xs input) b,
200+
SameLength' b (ForwardMap xs input),
201+
SameLength' (ForwardMap xs input) b, Outputs xs input
202+
) =>
203+
a -> input -> HList b
204+
toMaybeOutputShapes f a = hMap (safeEval . shape) (toOutputs f a)

src/Torch/Compose/Models.hs

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
{-# LANGUAGE GADTs#-}
1212
{-# LANGUAGE OverloadedRecordDot#-}
1313
{-# LANGUAGE FlexibleContexts#-}
14+
{-# LANGUAGE DataKinds#-}
1415

1516
module Torch.Compose.Models where
1617

1718
import Torch
1819
import Torch.Compose
1920
import Torch.Compose.NN
2021
import GHC.Generics hiding ((:+:))
22+
import Data.HList
2123

2224
vgg16Spec numClass =
2325
let maxPool2dSpec = MaxPool2dSpec
@@ -28,34 +30,37 @@ vgg16Spec numClass =
2830
, ceilMode = Ceil
2931
}
3032
vggClassifierSpec =
31-
LinearSpec (512 * 7 * 7) 4096 :>>:
32-
ReluSpec :>>:
33-
DropoutSpec 0.5 :>>:
34-
LinearSpec 4096 4096 :>>:
35-
ReluSpec :>>:
36-
DropoutSpec 0.5 :>>:
37-
LinearSpec 4096 numClass
33+
LinearSpec (512 * 7 * 7) 4096 .*.
34+
ReluSpec .*.
35+
DropoutSpec 0.5 .*.
36+
LinearSpec 4096 4096 .*.
37+
ReluSpec .*.
38+
DropoutSpec 0.5 .*.
39+
LinearSpec 4096 numClass .*.
40+
HNil
41+
conv2dSpec inChannel outChannel kernelHeight kernelWidth =
42+
Conv2dSpec' inChannel outChannel kernelHeight kernelWidth (1,1) (0,0)
3843
in
39-
Conv2dSpec 3 64 3 3 :>>:
40-
Conv2dSpec 64 64 3 3 :>>:
41-
maxPool2dSpec :>>:
42-
Conv2dSpec 64 128 3 3 :>>:
43-
Conv2dSpec 128 128 3 3 :>>:
44-
maxPool2dSpec :>>:
45-
Conv2dSpec 128 256 3 3 :>>:
46-
Conv2dSpec 256 256 3 3 :>>:
47-
Conv2dSpec 256 256 3 3 :>>:
48-
maxPool2dSpec :>>:
49-
Conv2dSpec 256 512 3 3 :>>:
50-
Conv2dSpec 512 512 3 3 :>>:
51-
Conv2dSpec 512 512 3 3 :>>:
52-
AdaptiveAvgPool2dSpec (7,7) :>>:
53-
ReshapeSpec [1,512*7*7] :>>:
44+
conv2dSpec 3 64 3 3 .*.
45+
conv2dSpec 64 64 3 3 .*.
46+
maxPool2dSpec .*.
47+
conv2dSpec 64 128 3 3 .*.
48+
conv2dSpec 128 128 3 3 .*.
49+
maxPool2dSpec .*.
50+
conv2dSpec 128 256 3 3 .*.
51+
conv2dSpec 256 256 3 3 .*.
52+
conv2dSpec 256 256 3 3 .*.
53+
maxPool2dSpec .*.
54+
conv2dSpec 256 512 3 3 .*.
55+
conv2dSpec 512 512 3 3 .*.
56+
conv2dSpec 512 512 3 3 .*.
57+
AdaptiveAvgPool2dSpec (7,7) .*.
58+
ReshapeSpec [1,512*7*7] .*.
5459
vggClassifierSpec
5560

56-
57-
resnetSpec numClass =
58-
Conv2dSpec 3 64 7 7 :>>:
59-
BatchNorm2dSpec 64 :>>:
60-
ReluSpec :>>:
61-
Conv2dSpec 64 64 3 3
61+
-- resnetSpec numClass =
62+
-- conv2dSpec 3 64 7 7 .*.
63+
-- BatchNorm2dSpec 64 .*.
64+
-- ReluSpec .*.
65+
-- conv2dSpec 64 64 3 3 .*.
66+
-- HNil

0 commit comments

Comments
 (0)