@@ -330,9 +330,12 @@ void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]
     void* acts_memory;
     cudaCheck(cudaMalloc((void**)&acts_memory, bytes));
-    #ifdef BUILD_AMD
-    cudaCheck(cudaMemset(acts_memory, 0, bytes)); // TODO: fix this properly :p
-    #endif
+
+    // cudaMalloc does not guarantee initial memory values so we memset the allocation here
+    // this matters because e.g. non-cuDNN attention assumes the attention buffer is zeroed
+    // todo - up to ~100ms on slow GPUs, could theoretically be more selective, but this is safer
+    cudaCheck(cudaMemset(acts_memory, 0, bytes));
+
     char* acts_memory_iterator = (char*)acts_memory;
     for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
         // extra protection so we don't accidentally use an empty buffer
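
For context, a minimal standalone sketch of the allocate-then-zero pattern added above. The `cuda_check` helper and the buffer size are illustrative stand-ins (not from the patch) for the repo's `cudaCheck` macro and the real activation byte count:

// sketch: cudaMalloc returns uninitialized memory, so it is zeroed explicitly
// before any kernel reads regions of it that were never written
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

static void cuda_check(cudaError_t err) {             // hypothetical stand-in for cudaCheck
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
        exit(EXIT_FAILURE);
    }
}

int main(void) {
    size_t bytes = 16 * 1024 * 1024;                  // arbitrary size for illustration
    void* buf;
    cuda_check(cudaMalloc(&buf, bytes));
    cuda_check(cudaMemset(buf, 0, bytes));            // cheap on fast GPUs, up to ~100ms on slow ones
    cuda_check(cudaFree(buf));
    return 0;
}
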
@@ -626,9 +629,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float)));
         cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float)));
     } else {
-        // validate B,T is consistent with how we've allocated the memory before
-        // in principle we could get more clever here in the future, for now this is safest
-        if (B != model->batch_size || T != model->seq_len) {
+        // validate B,T are not larger than the values used at initialisation
+        // (smaller B,T are okay for inference only)
+        if (B > model->batch_size || T > model->seq_len) {
             printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T);
             exit(EXIT_FAILURE);
         }
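
As a rough illustration of the relaxed check (a hypothetical helper, not part of the patch): shapes at or below the values used at initialisation are accepted and reuse the existing buffers, larger ones are rejected.

// sketch: smaller B,T fit inside the buffers allocated at init time
static int shape_ok(int alloc_B, int alloc_T, int B, int T) {
    return B <= alloc_B && T <= alloc_T;
}
// e.g. with buffers allocated for B=4, T=1024:
//   shape_ok(4, 1024, 1, 256)  -> 1  (smaller shape, fine for inference)
//   shape_ok(4, 1024, 8, 1024) -> 0  (would overflow the activation buffers)
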
@@ -692,6 +695,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream);
         #else
         floatX* l_att = acts.att + l * B * NH * T * T;
+        if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent)
+            cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX)));
+        }
         // these are only needed as scratchpads for the forward pass, but
         // need not be stored for backward
         matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);
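
To give a sense of what the new memset touches, a small sketch of the per-layer attention-scratch size it clears; the concrete B, NH, T and element size below are illustrative assumptions, not values from the patch:

#include <stdio.h>

// sketch: bytes cleared by the memset above for one layer's attention buffer
static size_t att_bytes(size_t B, size_t NH, size_t T, size_t elem_size) {
    return B * NH * T * T * elem_size;
}

int main(void) {
    // e.g. B=1, NH=12, bf16 elements (2 bytes): T=256 clears 1.5 MiB, T=1024 clears 24 MiB,
    // so clearing only the T-sized region used this step is much cheaper than the full buffer
    printf("%zu\n", att_bytes(1, 12, 256, 2));   // 1572864
    printf("%zu\n", att_bytes(1, 12, 1024, 2));  // 25165824
    return 0;
}
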
@@ -1756,14 +1762,14 @@ int main(int argc, char *argv[]) {
             printf("generating:\n---\n");
             for (int t = 1; t < genT; t++) {
                 NvtxRange generation_range("Generation step", t);
-                // note that inference is very wasteful here because for each token
-                // we re-calculate the forward pass for all of (B,T) positions from scratch
-                // but the inference here is just for sanity checking anyway
-                // and we can maybe optimize a bit more later, with careful tests
-                gpt2_forward(&model, gen_tokens, B, T);
-                // furthermore, below we're only using b=0 (i.e. the first row) of all B rows
-                // we're in principle running B "inference streams" in parallel here
-                // only using position 0 because it's a bit faster (copy less probs from GPU -> CPU)
+                // we try not to be too wasteful for inference by not calculating all of B,T
+                // using a smaller B is always bit-for-bit identical, but T is more tricky
+                // for non-cuDNN, we need to make sure the attention buffer is memset to 0
+                // for cuDNN, it might suddenly decide to use a slightly different algorithm...
+                // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical
+                // (but even if it wasn't fully identical that's probably not the end of the world)
+                // note this is still somewhat wasteful because we don't have a KV cache!
+                gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T, 256)) * min(T, 256));
                 // get the V-dimensional vector probs[0, t-1, :]
                 floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size;
                 // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding)
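
For reference, a small worked sketch of the sequence-length rounding in the new gpt2_forward call. CEIL_DIV matches the macro used in the patch; the explicit ternary (standing in for min(T, 256)) and the sample T value are assumptions for illustration:

#include <stdio.h>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// sketch: forward length used at generation step t, rounded up to a multiple of min(T, 256)
static int forward_len(int t, int T) {
    int chunk = T < 256 ? T : 256;      // stands in for min(T, 256)
    return CEIL_DIV(t, chunk) * chunk;
}

int main(void) {
    // e.g. with T = 1024: steps 1..256 run the forward pass on 256 positions,
    // steps 257..512 on 512 positions, and so on, instead of always using all 1024
    printf("%d %d %d\n", forward_len(1, 1024), forward_len(300, 1024), forward_len(513, 1024)); // 256 512 768
    return 0;
}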