Skip to content

Commit 293c400

Browse files
committed
Merge remote-tracking branch 'upstream/master'
2 parents fd9ff98 + b4623bc commit 293c400

15 files changed

+356
-131
lines changed

dev/cuda/Makefile

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-
3030
$(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@
3131

3232
# Build all targets
33-
TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm
33+
TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute
3434
all: $(TARGETS)
3535
all_ptx: $(TARGETS:%=%.ptx)
3636
all_sass: $(TARGETS:%=%.sass)
@@ -64,6 +64,8 @@ matmul_backward: matmul_backward.cu
6464
adamw: adamw.cu
6565
global_norm: global_norm.cu
6666

67+
permute: permute.cu
68+
6769
# NCCL communication kernels
6870
nccl_all_reduce: nccl_all_reduce.cu
6971
$(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce

dev/cuda/attention_backward.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,6 +1137,7 @@ int main(int argc, char **argv) {
11371137
free(dinp);
11381138
free(dpreatt);
11391139
free(datt);
1140+
free(h_dinp);
11401141
cudaCheck(cudaFree(d_inp));
11411142
cudaCheck(cudaFree(d_qkvr));
11421143
cudaCheck(cudaFree(d_preatt));

dev/cuda/attention_forward.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,7 @@ int main(int argc, char **argv) {
13771377
cudaCheck(cudaFree(d_preatt));
13781378
cudaCheck(cudaFree(d_att));
13791379
cudaCheck(cudaFree(d_inp));
1380+
cudaCheck(cudaFree(d_stats));
13801381
cublasDestroy(cublas_handle);
13811382

13821383
#ifdef ENABLE_CUDNN

dev/cuda/classifier_fused.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,7 @@ int main(int argc, char **argv) {
766766
cudaCheck(cudaFree(d_logits));
767767
cudaCheck(cudaFree(d_dlosses));
768768
cudaCheck(cudaFree(d_targets));
769+
cudaCheck(cudaFree(d_dlogits_no_pad));
769770

770771
return 0;
771772
}

dev/cuda/nccl_all_reduce.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,5 +193,6 @@ int main(int argc, char **argv) {
193193

194194
free(all_reduce_buffer_host);
195195
cudaCheck(cudaFree(all_reduce_buffer));
196+
cudaCheck(cudaFree(all_reduce_buffer_recv));
196197
multi_gpu_config_free(&multi_gpu_config);
197198
}

dev/cuda/permute.cu

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
/*
2+
Kernels to demonstrate permute operation.
3+
4+
Compile example:
5+
nvcc -O3 permute.cu -o permute
6+
7+
The goal is to permute a 4D matrix from its original shape (dim1, dim2, dim3, dim4) to a new shape (dim4, dim3, dim1, dim2).
8+
9+
Before permutation, we need to understand how to access elements in a flattened (linear) form of the matrix.
10+
11+
Given:
12+
13+
dim1 = size of the 1st dimension
14+
dim2 = size of the 2nd dimension
15+
dim3 = size of the 3rd dimension
16+
dim4 = size of the 4th dimension
17+
18+
For any element in a 4D matrix at position (i1, i2, i3, i4), where:
19+
20+
i1 is the index in dimension 1
21+
i2 is the index in dimension 2
22+
i3 is the index in dimension 3
23+
i4 is the index in dimension 4
24+
25+
If you find it challenging to calculate the indices i1, i2, i3, and i4, observe the pattern in the index calculations.
26+
Initially, it might take some time to grasp, but with practice, you'll develop a mental model for it.
27+
28+
To calculate the indices, use the following formulas:
29+
30+
i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
31+
i2 = (idx / (dim3 * dim4)) % dim2;
32+
i3 = (idx / dim4) % dim3;
33+
i4 = idx % dim4;
34+
35+
Pattern Explanation:
36+
To find the index for any dimension, divide the thread ID (idx) by the product of all subsequent dimensions.
37+
Then, perform modulo operation with the current dimension.
38+
39+
40+
41+
The linear index in a flattened 1D array is calculated as:
42+
linear_idx = i1 × ( dim2 × dim3 × dim4 ) + i2 × ( dim3 × dim4 ) + i3 × dim4 + i4
43+
This linear index uniquely identifies the position of the element in the 1D array.
44+
45+
To permute the matrix, we need to rearrange the indices according to the new shape.
46+
In this case, we are permuting from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2).
47+
48+
After the permutation, the dimensions are reordered as follows:
49+
50+
dim1 becomes the new 3rd dimension.
51+
dim2 becomes the new 4th dimension.
52+
dim3 becomes the new 2nd dimension.
53+
dim4 becomes the new 1st dimension.
54+
55+
permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
56+
57+
Here's how this works:
58+
59+
i4 * (dim3 * dim1 * dim2): This accounts for how many complete dim3 × dim1 × dim2 blocks fit before the current i4 block.
60+
i3 * (dim1 * dim2): This accounts for the offset within the current i4 block, specifying which i3 block we are in.
61+
i1 * dim2: This accounts for the offset within the current i3 block, specifying which i1 block we are in.
62+
i2: This gives the offset within the current i1 block.
63+
64+
Finally, we copy the value stored at index idx of the original matrix into index permuted_idx of the permuted matrix.
65+
66+
67+
--------------------------------------------------------------------------------------------------------------------------------------------------------
68+
69+
Similarly we can follow the above approach to permute matrices of any dimensions.
70+
71+
*/
72+
73+
74+
#include <cuda_runtime.h>
75+
#include <stdio.h>
76+
#include <stdlib.h>
77+
#include <cmath>
78+
79+
// CPU reference implementation of the permutation.
//
// Reads a row-major 4D tensor of shape (dim1, dim2, dim3, dim4) from `matrix`
// and writes it to `out_matrix` in row-major order for the shape
// (dim4, dim3, dim1, dim2).
//
//   matrix      input buffer holding dim1 * dim2 * dim3 * dim4 floats
//   out_matrix  output buffer with the same number of floats; it must be a
//               different buffer from `matrix`, because elements are written to
//               new positions while the input is still being read
//   dim1..dim4  extents of the four input dimensions
//
// All index arithmetic is carried out in size_t so that tensors with more than
// INT_MAX elements do not overflow. Nothing is written when any extent is <= 0.
void permute_cpu(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
    if (dim1 <= 0 || dim2 <= 0 || dim3 <= 0 || dim4 <= 0) {
        return;
    }

    const size_t d1 = (size_t)dim1;
    const size_t d2 = (size_t)dim2;
    const size_t d3 = (size_t)dim3;
    const size_t d4 = (size_t)dim4;
    const size_t total_elements = d1 * d2 * d3 * d4;

    for (size_t idx = 0; idx < total_elements; idx++) {
        // Decompose the linear input index into its 4D coordinates.
        // idx < total_elements, so the leading coordinate needs no modulo.
        const size_t i1 = idx / (d2 * d3 * d4);
        const size_t i2 = (idx / (d3 * d4)) % d2;
        const size_t i3 = (idx / d4) % d3;
        const size_t i4 = idx % d4;

        // Linear index of the same element in the (dim4, dim3, dim1, dim2) layout.
        const size_t permuted_idx = i4 * (d3 * d1 * d2) + i3 * (d1 * d2) + i1 * d2 + i2;
        out_matrix[permuted_idx] = matrix[idx];
    }
}
96+
97+
// CUDA kernel: permutes a row-major 4D tensor from shape
// (dim1, dim2, dim3, dim4) to shape (dim4, dim3, dim1, dim2).
//
//   matrix      input buffer holding dim1 * dim2 * dim3 * dim4 floats
//   out_matrix  output buffer with the same number of floats; must not be the
//               same buffer as `matrix`
//   dim1..dim4  extents of the four input dimensions
//
// Launch layout: 1D grid of 1D blocks (only the .x components are used).
// The kernel uses a grid-stride loop, so any grid size produces a complete
// result; launching ceil(total / blockDim.x) blocks gives one element per thread.
// Index arithmetic is done in size_t to avoid int overflow on large tensors.
__global__ void permute_cuda(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
    if (dim1 <= 0 || dim2 <= 0 || dim3 <= 0 || dim4 <= 0) {
        return;
    }

    const size_t d1 = (size_t)dim1;
    const size_t d2 = (size_t)dim2;
    const size_t d3 = (size_t)dim3;
    const size_t d4 = (size_t)dim4;
    const size_t total_elements = d1 * d2 * d3 * d4;

    const size_t stride = (size_t)gridDim.x * blockDim.x;
    for (size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += stride) {
        // Decompose the linear input index into its 4D coordinates.
        // idx < total_elements, so the leading coordinate needs no modulo.
        const size_t i1 = idx / (d2 * d3 * d4);
        const size_t i2 = (idx / (d3 * d4)) % d2;
        const size_t i3 = (idx / d4) % d3;
        const size_t i4 = idx % d4;

        // Linear index of the same element in the (dim4, dim3, dim1, dim2) layout.
        const size_t permuted_idx = i4 * (d3 * d1 * d2) + i3 * (d1 * d2) + i1 * d2 + i2;
        out_matrix[permuted_idx] = matrix[idx];
    }
}
115+
116+
// Compares the GPU result against the CPU reference element by element.
//
//   permuted_matrix_cuda  result copied back from the device
//   permuted_matrix_cpu   reference result computed on the host
//   totalElements         number of floats to compare in each buffer
//
// Returns true when every pair of elements is equal or differs by at most
// 1e-5, and prints "Permute Operation Passed". On the first mismatch it prints
// the failing index and both values and returns false. A NaN on either side is
// treated as a mismatch.
bool verify_results(const float* permuted_matrix_cuda, const float* permuted_matrix_cpu, int totalElements) {
    const double tolerance = 1e-5;

    for (int i = 0; i < totalElements; i++) {
        const float gpu_value = permuted_matrix_cuda[i];
        const float cpu_value = permuted_matrix_cpu[i];

        // Identical values (including matching infinities) are always accepted;
        // this also avoids inf - inf producing NaN below.
        if (gpu_value == cpu_value) {
            continue;
        }

        // Written as !(diff <= tolerance) rather than (diff > tolerance) so that
        // a NaN difference fails the check instead of slipping through.
        const float diff = fabs(gpu_value - cpu_value);
        if (!(diff <= tolerance)) {
            printf("Permute Operation Failed\n");
            printf("Mismatch at index %d\n", i);
            printf("CPU: %f\n", cpu_value);
            printf("CUDA: %f\n", gpu_value);
            return false; // stop at the first failure
        }
    }

    printf("Permute Operation Passed\n");
    return true;
}
134+
135+
// Fills a 4D tensor buffer with pseudo-random values in [0, 1].
//
//   mat            buffer with room for dim_1 * dim_2 * dim_3 * dim_4 floats
//   dim_1..dim_4   extents of the four dimensions
//
// Values come from rand(), so the sequence depends on the current srand() seed.
// The element count is computed once in size_t to avoid int overflow; nothing
// is written when any extent is <= 0. Prints "Matrix Initialized" when done.
void initialize_matrix(float* mat, int dim_1, int dim_2, int dim_3, int dim_4) {
    size_t count = 0;
    if (dim_1 > 0 && dim_2 > 0 && dim_3 > 0 && dim_4 > 0) {
        count = (size_t)dim_1 * (size_t)dim_2 * (size_t)dim_3 * (size_t)dim_4;
    }

    for (size_t i = 0; i < count; ++i) {
        mat[i] = static_cast<float>(rand()) / RAND_MAX;
    }
    printf("Matrix Initialized\n");
}
142+
143+
// Prints a diagnostic and terminates the process if a CUDA runtime call failed.
// `what` names the operation for the error message.
static void permute_check_cuda(cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error during %s: %s\n", what, cudaGetErrorString(status));
        exit(EXIT_FAILURE);
    }
}

// Test driver: permutes a random (24, 42, 20, 32) tensor on both the CPU and
// the GPU and compares the results. Returns EXIT_SUCCESS only when the GPU
// output matches the CPU reference, so the binary can be used in scripts.
int main() {
    const int dim_1 = 24;
    const int dim_2 = 42;
    const int dim_3 = 20;
    const int dim_4 = 32;
    const size_t num_elements = (size_t)dim_1 * dim_2 * dim_3 * dim_4;
    const size_t num_bytes = num_elements * sizeof(float);

    // Set up the device
    int deviceIdx = 0;
    permute_check_cuda(cudaSetDevice(deviceIdx), "cudaSetDevice");
    cudaDeviceProp deviceProp;
    permute_check_cuda(cudaGetDeviceProperties(&deviceProp, deviceIdx), "cudaGetDeviceProperties");
    printf("Device %d: %s\n", deviceIdx, deviceProp.name);

    // Allocate host memory
    float* matrix = (float*)malloc(num_bytes);
    float* permuted_matrix = (float*)malloc(num_bytes);
    float* permuted_matrix_cpu = (float*)malloc(num_bytes);
    if (matrix == NULL || permuted_matrix == NULL || permuted_matrix_cpu == NULL) {
        fprintf(stderr, "Host memory allocation failed\n");
        free(matrix);
        free(permuted_matrix);
        free(permuted_matrix_cpu);
        return EXIT_FAILURE;
    }

    // Initialize the matrix with random values
    initialize_matrix(matrix, dim_1, dim_2, dim_3, dim_4);

    // Allocate device memory
    float* d_matrix = NULL;
    float* d_permuted_matrix = NULL;
    permute_check_cuda(cudaMalloc(&d_matrix, num_bytes), "cudaMalloc(d_matrix)");
    permute_check_cuda(cudaMalloc(&d_permuted_matrix, num_bytes), "cudaMalloc(d_permuted_matrix)");

    // Copy matrix from host to device
    permute_check_cuda(cudaMemcpy(d_matrix, matrix, num_bytes, cudaMemcpyHostToDevice), "host-to-device copy");

    // Perform permutation on CPU (reference result)
    permute_cpu(matrix, permuted_matrix_cpu, dim_1, dim_2, dim_3, dim_4);

    // One thread per element: ceil(num_elements / blockSize) blocks
    const int blockSize = 256;
    const int gridSize = (int)((num_elements + blockSize - 1) / blockSize);

    // Launch CUDA kernel to perform permutation
    permute_cuda<<<gridSize, blockSize>>>(d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4);
    // A kernel launch returns no status itself: launch-configuration errors are
    // reported by cudaGetLastError, execution errors by the following sync.
    permute_check_cuda(cudaGetLastError(), "permute_cuda launch");
    permute_check_cuda(cudaDeviceSynchronize(), "permute_cuda execution");

    // Copy the result from device to host
    permute_check_cuda(cudaMemcpy(permuted_matrix, d_permuted_matrix, num_bytes, cudaMemcpyDeviceToHost), "device-to-host copy");

    // Verify results
    const bool passed = verify_results(permuted_matrix, permuted_matrix_cpu, (int)num_elements);

    // Free allocated memory
    free(matrix);
    free(permuted_matrix);
    free(permuted_matrix_cpu);
    permute_check_cuda(cudaFree(d_matrix), "cudaFree(d_matrix)");
    permute_check_cuda(cudaFree(d_permuted_matrix), "cudaFree(d_permuted_matrix)");

    return passed ? EXIT_SUCCESS : EXIT_FAILURE;
}
199+

dev/cuda/trimat_forward.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ int main(int argc, char **argv) {
643643
free(inp);
644644
cudaCheck(cudaFree(d_out));
645645
cudaCheck(cudaFree(d_inp));
646+
cudaCheck(cudaFree(d_qkvr));
646647
cublasDestroy(cublas_handle);
647648

648649
return 0;

dev/unistd.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <string.h>
1414
#include <direct.h> // for _mkdir and _stat
1515
#include <io.h> // needed for _access below and _findfirst, _findnext, _findclose
16+
#pragma comment(lib, "Ws2_32.lib") // Link Ws2_32.lib for socket functions
17+
#include <winsock2.h>
1618

1719
#define CLOCK_MONOTONIC 0
1820
static inline int clock_gettime(int ignore_variable, struct timespec* tv)

llmc/cuda_utils.cuh

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,36 @@ __device__ void store128cg(ElementType* target, Packed128<ElementType> value) {
7979
typedef Packed128<float> f128;
8080
typedef Packed128<floatX> x128;
8181

82+
// ----------------------------------------------------------------------------
83+
// DType support
84+
85+
// Enumerator to identify the data type of a tensor.
86+
enum class DType : uint8_t {
87+
FP32, FP16, BF16 // correspond to float, half and nv_bfloat16 respectively (see sizeof_dtype / dtype_of)
88+
};
89+
90+
// Given a datatype enum, returns the number of bytes occupied by one scalar
// of that type: sizeof(float), sizeof(half) or sizeof(nv_bfloat16).
//
// Declared inline because the function is defined in a header; without it,
// including this header from more than one translation unit produces
// multiple-definition errors at link time.
//
// For an unrecognized enumerator value, prints "Unknown datatype" to stderr
// and terminates the process.
inline size_t sizeof_dtype(DType type) {
    switch (type) {
        case DType::FP32:
            return sizeof(float);
        case DType::FP16:
            return sizeof(half);
        case DType::BF16:
            return sizeof(nv_bfloat16);
        default: // not reachable for the enumerators above; avoids a missing-return warning
            fprintf(stderr, "Unknown datatype\n");
            exit(EXIT_FAILURE);
    }
}
105+
106+
// Overloads that map a pointer's element type to the matching DType tag.
// Only the static type of the argument is used; the pointer is never
// dereferenced, so the parameters are left unnamed. Declared inline because
// they are defined in a header that may be included by several translation units.
inline DType dtype_of(float* /*f*/) { return DType::FP32; }
inline DType dtype_of(nv_bfloat16* /*f*/) { return DType::BF16; }
inline DType dtype_of(half* /*f*/) { return DType::FP16; }
109+
110+
111+
82112
// ----------------------------------------------------------------------------
83113
// Copy, cast functions
84114

llmc/cudnn_att.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// TODO this currently duplicates some of the utilities from the main file
44

55
#define NOMINMAX
6+
#include <unistd.h>
67
#include "cudnn_att.h"
78
#include <cudnn_frontend.h>
89

0 commit comments

Comments
 (0)