
Commit a1bd3f2

Multi head attention
1 parent 92c345f commit a1bd3f2

2 files changed: 110 additions & 0 deletions

flake.nix

Lines changed: 5 additions & 0 deletions
@@ -14,6 +14,10 @@
       inherit system;
       config.allowUnfree = true;
       config.cudaSupport = system == "x86_64-linux";
+      # config.ihaskell.packages = pkgs: with pkgs; [
+      #   hasktorch
+      #   hvega
+      # ];
     };
     ghcWithHasktorch = pkgs.haskellPackages.ghcWithPackages (pkgs: with pkgs; [
       hasktorch
@@ -26,6 +30,7 @@
         ghcWithHasktorch
         cabal-install
         stack
+        # ihaskell
       ];
       shellHook = ''
         source ${git}/share/bash-completion/completions/git-prompt.sh

src/Torch/Compose/NN.hs

Lines changed: 105 additions & 0 deletions
@@ -171,6 +171,111 @@ instance Randomizable Conv2dSpec' Conv2d' where
       , params = a
       }

+
+--------------------------------------------------------------------------------
+-- Multi-Head Attention Data Structures
+--------------------------------------------------------------------------------
+
+-- | Specification for initializing a MultiHeadAttention module.
+data MultiHeadAttentionSpec = MultiHeadAttentionSpec
+  { mhaEmbedDim :: Int -- ^ Model embedding dimension
+  , mhaNumHeads :: Int -- ^ Number of attention heads
+  } deriving (Show, Eq)
+
+-- | Data type that holds parameters for Multi-Head Attention.
+data MultiHeadAttention = MultiHeadAttention
+  { wQ :: Linear   -- ^ Linear projection for the queries
+  , wK :: Linear   -- ^ Linear projection for the keys
+  , wV :: Linear   -- ^ Linear projection for the values
+  , wO :: Linear   -- ^ Final linear projection after combining heads
+  , headDim :: Int -- ^ Dimension per head = embedDim / numHeads
+  , nHeads :: Int  -- ^ Number of attention heads
+  } deriving (Show)
+
+-- | Create random parameters for Multi-Head Attention given the specification.
+instance Randomizable MultiHeadAttentionSpec MultiHeadAttention where
+  sample MultiHeadAttentionSpec{..} = do
+    let headDim = mhaEmbedDim `Prelude.div` mhaNumHeads
+    wQ' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    wK' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    wV' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    wO' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    return $ MultiHeadAttention
+      { wQ = wQ'
+      , wK = wK'
+      , wV = wV'
+      , wO = wO'
+      , headDim = headDim
+      , nHeads = mhaNumHeads
+      }
+
+--------------------------------------------------------------------------------
+-- Forward Pass (Scaled Dot-Product Attention + Multi-Head Logic)
+--------------------------------------------------------------------------------
+
+-- | Compute scaled dot-product attention for query, key, value tensors.
+-- The typical shape for q, k, v is:
+--   [batchSize, numHeads, seqLen, headDim]
+--
+-- Returns: [batchSize, numHeads, seqLen, headDim]
+scaledDotProductAttention
+  :: Tensor -- ^ Queries (q)
+  -> Tensor -- ^ Keys (k)
+  -> Tensor -- ^ Values (v)
+  -> Tensor -- ^ Output (contextual embeddings)
+scaledDotProductAttention q k v =
+  let -- q*k^T -> [batchSize, numHeads, seqLen, seqLen]
+      dk = fromIntegral (shape q !! 3) -- headDim
+      scores = (q `matmul` transpose (Dim 2) (Dim 3) k) / Torch.sqrt (asTensor (dk :: Float))
+      attnWeights = softmax (Dim (-1)) scores -- softmax over last dim (seqLen)
+      output = attnWeights `matmul` v         -- multiply by values
+  in output
+
+-- | Forward pass for Multi-Head Attention (minimal: without any mask or dropout).
+multiHeadAttention
+  :: MultiHeadAttention
+  -> Tensor -- ^ Input queries [batchSize, seqLen, embedDim]
+  -> Tensor -- ^ Input keys    [batchSize, seqLen, embedDim]
+  -> Tensor -- ^ Input values  [batchSize, seqLen, embedDim]
+  -> Tensor -- ^ Output        [batchSize, seqLen, embedDim]
+multiHeadAttention MultiHeadAttention{..} queries keys values =
+  let
+    -- Project inputs to Q, K, V space
+    q = linear wQ queries -- [batchSize, seqLen, embedDim]
+    k = linear wK keys    -- [batchSize, seqLen, embedDim]
+    v = linear wV values  -- [batchSize, seqLen, embedDim]
+
+    batchSize = shape queries !! 0
+    seqLen = shape queries !! 1
+
+    -- Reshape for multi-head: [batchSize, seqLen, nHeads*headDim]
+    --                      -> [batchSize, seqLen, nHeads, headDim]
+    --                      -> [batchSize, nHeads, seqLen, headDim]
+    reshapeForHeads t =
+      let t'  = view [batchSize, seqLen, nHeads*headDim] t
+          t'' = view [batchSize, seqLen, nHeads, headDim] t'
+      in permute [0,2,1,3] t'' -- reorder dimensions to [batchSize, nHeads, seqLen, headDim]
+
+    qHeads = reshapeForHeads q
+    kHeads = reshapeForHeads k
+    vHeads = reshapeForHeads v
+
+    -- Apply scaled dot-product attention
+    attnOutput = scaledDotProductAttention qHeads kHeads vHeads
+    -- shape: [batchSize, nHeads, seqLen, headDim]
+
+    -- Convert back: [batchSize, nHeads, seqLen, headDim]
+    --            -> [batchSize, seqLen, nHeads, headDim]
+    --            -> [batchSize, seqLen, nHeads*headDim]
+    attnOutputTrans = permute [0,2,1,3] attnOutput
+    combinedHeads = view [batchSize, seqLen, nHeads*headDim] attnOutputTrans
+
+    -- Final linear
+    out = linear wO combinedHeads -- [batchSize, seqLen, embedDim]
+  in out
+
+
+
 -- Generate HasForwardAssoc instances from HasForward instances.
 instanceForwardAssocs
   [ [t| Linear |]
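
A minimal usage sketch (not part of this commit) for trying out the new module: it assumes the definitions above are in scope together with the Torch API used elsewhere in the file, that embedDim is divisible by numHeads, and that random input embeddings are fabricated with randnIO'.

-- Hypothetical usage sketch, not included in the diff above.
main :: IO ()
main = do
  let embedDim = 64
      numHeads = 8
      batchSize = 2
      seqLen = 10
  -- Sample random projection weights from the spec
  mha <- sample $ MultiHeadAttentionSpec embedDim numHeads
  -- Fabricate random input embeddings [batchSize, seqLen, embedDim]
  x <- randnIO' [batchSize, seqLen, embedDim]
  -- Self-attention: queries, keys and values are the same tensor
  let out = multiHeadAttention mha x x x
  print (shape out) -- expected: [2,10,64]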
