@@ -171,6 +171,111 @@ instance Randomizable Conv2dSpec' Conv2d' where
, params = a
}
+
+--------------------------------------------------------------------------------
+-- Multi-Head Attention Data Structures
+--------------------------------------------------------------------------------
+
+-- | Specification for initializing a MultiHeadAttention module.
+data MultiHeadAttentionSpec = MultiHeadAttentionSpec
+  { mhaEmbedDim :: Int -- ^ Model embedding dimension
+  , mhaNumHeads :: Int -- ^ Number of attention heads
+  } deriving (Show, Eq)
+
+-- | Data type that holds the parameters for Multi-Head Attention.
+data MultiHeadAttention = MultiHeadAttention
+  { wQ      :: Linear -- ^ Linear projection for the queries
+  , wK      :: Linear -- ^ Linear projection for the keys
+  , wV      :: Linear -- ^ Linear projection for the values
+  , wO      :: Linear -- ^ Final linear projection after combining heads
+  , headDim :: Int    -- ^ Dimension per head = embedDim / numHeads
+  , nHeads  :: Int    -- ^ Number of attention heads
+  } deriving (Show)
+
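+-- Note (illustrative, not enforced by the types): the code below assumes
+-- mhaEmbedDim is an exact multiple of mhaNumHeads. headDim is computed with
+-- integer division and the heads are later recombined into an
+-- nHeads * headDim wide tensor, so a non-divisible embedding dimension would
+-- make the shapes disagree. For example, mhaEmbedDim = 512 with
+-- mhaNumHeads = 8 gives headDim = 64.
+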
+-- | Create random parameters for Multi-Head Attention given the specification.
+instance Randomizable MultiHeadAttentionSpec MultiHeadAttention where
+  sample MultiHeadAttentionSpec {..} = do
+    let headDim = mhaEmbedDim `Prelude.div` mhaNumHeads
+    wQ' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    wK' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    wV' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    wO' <- sample $ LinearSpec mhaEmbedDim mhaEmbedDim
+    return $ MultiHeadAttention
+      { wQ      = wQ'
+      , wK      = wK'
+      , wV      = wV'
+      , wO      = wO'
+      , headDim = headDim
+      , nHeads  = mhaNumHeads
+      }
+
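+-- A sketch of how the instance is used (sizes are arbitrary, run in IO):
+--
+-- > mha <- sample (MultiHeadAttentionSpec { mhaEmbedDim = 512, mhaNumHeads = 8 })
+--
+-- This draws four independent embedDim x embedDim 'Linear' layers
+-- (512 x 512 here) for wQ, wK, wV and wO, and records nHeads = 8 and
+-- headDim = 64.
+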
+--------------------------------------------------------------------------------
+-- Forward Pass (Scaled Dot-Product Attention + Multi-Head Logic)
+--------------------------------------------------------------------------------
+
+-- | Compute scaled dot-product attention for query, key, value tensors.
+-- The typical shape for q, k, v is:
+--   [batchSize, numHeads, seqLen, headDim]
+--
+-- Returns: [batchSize, numHeads, seqLen, headDim]
+scaledDotProductAttention
+  :: Tensor -- ^ Queries (q)
+  -> Tensor -- ^ Keys (k)
+  -> Tensor -- ^ Values (v)
+  -> Tensor -- ^ Output (contextual embeddings)
+scaledDotProductAttention q k v =
+  let -- q * k^T -> [batchSize, numHeads, seqLen, seqLen]
+      dk          = fromIntegral (shape q !! 3) -- headDim
+      scores      = (q `matmul` transpose (Dim 2) (Dim 3) k) / Torch.sqrt (asTensor (dk :: Float))
+      attnWeights = softmax (Dim (-1)) scores -- softmax over the last dim (seqLen)
+      output      = attnWeights `matmul` v -- weight the values by the attention weights
+  in output
+
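+-- In formula form this is softmax(q k^T / sqrt(headDim)) v, applied
+-- independently per batch element and per head. A shape-level sketch
+-- (illustrative only; 'randnIO'' is assumed to be in scope from Torch):
+--
+-- > q <- randnIO' [2, 8, 10, 64] -- [batchSize, numHeads, seqLen, headDim]
+-- > k <- randnIO' [2, 8, 10, 64]
+-- > v <- randnIO' [2, 8, 10, 64]
+-- > shape (scaledDotProductAttention q k v) -- [2, 8, 10, 64]
+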
+-- | Forward pass for Multi-Head Attention (minimal: no mask and no dropout).
+multiHeadAttention
+  :: MultiHeadAttention
+  -> Tensor -- ^ Input queries [batchSize, seqLen, embedDim]
+  -> Tensor -- ^ Input keys    [batchSize, seqLen, embedDim]
+  -> Tensor -- ^ Input values  [batchSize, seqLen, embedDim]
+  -> Tensor -- ^ Output        [batchSize, seqLen, embedDim]
+multiHeadAttention MultiHeadAttention {..} queries keys values =
+  let
+      -- Project inputs to Q, K, V space
+      q = linear wQ queries -- [batchSize, seqLen, embedDim]
+      k = linear wK keys    -- [batchSize, seqLen, embedDim]
+      v = linear wV values  -- [batchSize, seqLen, embedDim]
+
+      batchSize = shape queries !! 0
+      seqLen    = shape queries !! 1
+
+      -- Reshape for multi-head: [batchSize, seqLen, nHeads*headDim]
+      --                      -> [batchSize, seqLen, nHeads, headDim]
+      --                      -> [batchSize, nHeads, seqLen, headDim]
+      reshapeForHeads t =
+        let t'  = view [batchSize, seqLen, nHeads * headDim] t
+            t'' = view [batchSize, seqLen, nHeads, headDim] t'
+        in permute [0, 2, 1, 3] t'' -- reorder to [batchSize, nHeads, seqLen, headDim]
+
+      qHeads = reshapeForHeads q
+      kHeads = reshapeForHeads k
+      vHeads = reshapeForHeads v
+
+      -- Apply scaled dot-product attention
+      attnOutput = scaledDotProductAttention qHeads kHeads vHeads
+      -- shape: [batchSize, nHeads, seqLen, headDim]
+
+      -- Convert back: [batchSize, nHeads, seqLen, headDim]
+      --            -> [batchSize, seqLen, nHeads, headDim]
+      --            -> [batchSize, seqLen, nHeads*headDim]
+      attnOutputTrans = permute [0, 2, 1, 3] attnOutput
+      -- 'reshape' rather than 'view': the permuted tensor is no longer contiguous.
+      combinedHeads   = reshape [batchSize, seqLen, nHeads * headDim] attnOutputTrans
+
+      -- Final linear projection
+      out = linear wO combinedHeads -- [batchSize, seqLen, embedDim]
+  in out
+
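+-- End-to-end usage sketch for self-attention (queries = keys = values);
+-- illustrative only, with arbitrary sizes and 'randnIO'' assumed in scope:
+--
+-- > mha <- sample (MultiHeadAttentionSpec { mhaEmbedDim = 512, mhaNumHeads = 8 })
+-- > x   <- randnIO' [4, 16, 512] -- [batchSize, seqLen, embedDim]
+-- > let contextual = multiHeadAttention mha x x x
+-- > shape contextual -- [4, 16, 512]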
+
+
-- Generate HasForwardAssoc instances from HasForward instances.
instanceForwardAssocs
[ [t | Linear |]