
Commit fdb47ca

Merge remote-tracking branch 'upstream/master'
2 parents: c9949e9 + bdb0fb5, commit fdb47ca

File tree

1 file changed: +20 -14 lines

train_gpt2.cu

Lines changed: 20 additions & 14 deletions
@@ -330,9 +330,12 @@ void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]

     void* acts_memory;
     cudaCheck(cudaMalloc((void**)&acts_memory, bytes));
-    #ifdef BUILD_AMD
-    cudaCheck(cudaMemset(acts_memory, 0, bytes)); // TODO: fix this properly :p
-    #endif
+
+    // cudaMalloc does not guarantee initial memory values so we memset the allocation here
+    // this matters because e.g. non-cuDNN attention assumes the attention buffer is zeroed
+    // todo - up to ~100ms on slow GPUs, could theoretically be more selective, but this is safer
+    cudaCheck(cudaMemset(acts_memory, 0, bytes));
+
     char* acts_memory_iterator = (char*)acts_memory;
     for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
         // extra protection so we don't accidentally use an empty buffer
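As a side note, the allocate-then-zero pattern this hunk settles on can be reproduced in a tiny standalone CUDA program. The sketch below is illustrative only; the 1 MiB size and the inline error handling (in place of llm.c's cudaCheck macro) are assumptions, not code from this commit:

    // standalone sketch (not llm.c code): cudaMalloc leaves device memory uninitialized,
    // so the allocation is zeroed once with cudaMemset before any kernel reads it
    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
        size_t bytes = 1024 * 1024;   // arbitrary 1 MiB allocation for illustration
        void* acts_memory = nullptr;
        cudaError_t err = cudaMalloc(&acts_memory, bytes);
        if (err != cudaSuccess) { printf("cudaMalloc: %s\n", cudaGetErrorString(err)); return 1; }
        // zero the whole block so kernels that assume zeroed scratch start from a known state
        err = cudaMemset(acts_memory, 0, bytes);
        if (err != cudaSuccess) { printf("cudaMemset: %s\n", cudaGetErrorString(err)); return 1; }
        cudaFree(acts_memory);
        return 0;
    }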
@@ -626,9 +629,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float)));
         cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float)));
     } else {
-        // validate B,T is consistent with how we've allocated the memory before
-        // in principle we could get more clever here in the future, for now this is safest
-        if (B != model->batch_size || T != model->seq_len) {
+        // validate B,T are not larger than the values used at initialisation
+        // (smaller B,T are okay for inference only)
+        if (B > model->batch_size || T > model->seq_len) {
             printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T);
             exit(EXIT_FAILURE);
         }
@@ -692,6 +695,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream);
         #else
         floatX* l_att = acts.att + l * B * NH * T * T;
+        if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent)
+            cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX)));
+        }
         // these are only needed as scratchpads for the forward pass, but
         // need not be stored for backward
         matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);
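For clarity, the per-layer indexing and memset added in this hunk can be written out as a small host-side helper. This is a hypothetical sketch rather than code from the diff; floatX is assumed to be plain float, the helper name is made up, and the sizes in main are chosen only for illustration:

    #include <cstdio>
    #include <cuda_runtime.h>

    typedef float floatX; // llm.c picks floatX via build flags; plain float is assumed here

    // zero one layer's (B, NH, T, T) block of attention scores, mirroring the indexing above
    void zero_attention_slice(floatX* att_base, int layer, size_t B, size_t NH, size_t T) {
        size_t slice_elems = B * NH * T * T;                     // elements per layer
        floatX* l_att = att_base + (size_t)layer * slice_elems;
        // entries the attention kernels never write (e.g. beyond the current T) must read as zero
        cudaError_t err = cudaMemset(l_att, 0, slice_elems * sizeof(floatX));
        if (err != cudaSuccess) { printf("cudaMemset: %s\n", cudaGetErrorString(err)); }
    }

    int main() {
        size_t L = 12, B = 4, NH = 12, T = 128;                  // illustrative GPT-2-small-like sizes
        floatX* att = nullptr;
        cudaMalloc((void**)&att, L * B * NH * T * T * sizeof(floatX));
        zero_attention_slice(att, /*layer=*/3, B, NH, T);
        cudaFree(att);
        return 0;
    }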
@@ -1756,14 +1762,14 @@ int main(int argc, char *argv[]) {
         printf("generating:\n---\n");
         for (int t = 1; t < genT; t++) {
             NvtxRange generation_range("Generation step", t);
-            // note that inference is very wasteful here because for each token
-            // we re-calculate the forward pass for all of (B,T) positions from scratch
-            // but the inference here is just for sanity checking anyway
-            // and we can maybe optimize a bit more later, with careful tests
-            gpt2_forward(&model, gen_tokens, B, T);
-            // furthermore, below we're only using b=0 (i.e. the first row) of all B rows
-            // we're in principle running B "inference streams" in parallel here
-            // only using position 0 because it's a bit faster (copy less probs from GPU -> CPU)
+            // we try not to be too wasteful for inference by not calculating all of B,T
+            // Using a smaller B is always bit-for-bit identical, but T is more tricky
+            // for non-CUDNN, we need to make sure the attention buffer is memset to 0
+            // for cuDNN, it might suddenly decide to use a slightly different algorithm...
+            // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical
+            // (but even if it wasn't fully identical that's probably not the end of the world)
+            // note this is still somewhat wasteful because we don't have a KV cache!
+            gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256));
             // get the V-dimensional vector probs[0, t-1, :]
             floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size;
             // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding)
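To see what the new call actually passes for T (alongside dropping B to 1), here is a small host-only sketch of the rounding. The CEIL_DIV macro is written to match its usual definition in llm.c, which is an assumption, and the T and t values are illustrative:

    #include <algorithm>
    #include <cstdio>

    #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))   // assumed to match llm.c's definition

    int main() {
        int T = 1024;                  // sequence length the model was initialised with
        int chunk = std::min(T, 256);  // generation rounds the forward T up to this step
        int ts[] = {1, 255, 256, 257, 600, 1000};
        for (int t : ts) {
            int forward_T = CEIL_DIV(t, chunk) * chunk;
            // prints 256, 256, 256, 512, 768, 1024: only a few distinct shapes reach cuDNN
            printf("t=%4d -> forward_T=%d\n", t, forward_T);
        }
        return 0;
    }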
