Commit 9ed3522

Merge branch 'master' into Nexes_CQ_10
2 parents 2a8dbf8 + 1d48e98

20 files changed: +727 −70 lines

Makefile

Lines changed: 1 addition & 1 deletion
@@ -611,7 +611,7 @@ ifdef GGML_CUDA
 
   MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
   MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
-  MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
+  MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22
 else
   ifneq ('', '$(wildcard /opt/cuda)')
    CUDA_PATH ?= /opt/cuda

README.md

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ Typically finetunes of the base models below are supported as well.
 - Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
 - Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp)
 - JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp)
+- JS/TS (Programmable Prompt Engine CLI): [offline-ai/cli](https://github.com/offline-ai/cli)
 - JavaScript/Wasm (works in browser): [tangledgroup/llama-cpp-wasm](https://github.com/tangledgroup/llama-cpp-wasm)
 - Typescript/Wasm (nicer API, available on npm): [ngxson/wllama](https://github.com/ngxson/wllama)
 - Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb)

examples/perplexity/perplexity.cpp

Lines changed: 0 additions & 2 deletions
@@ -444,7 +444,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         }
         LOG("%.2f minutes\n", total_seconds / 60.0);
     }
-    LOG("\n");
 
     //LOG_DBG("%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
     for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
@@ -638,7 +637,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         }
         LOG("%.2f minutes\n", total_seconds / 60.0);
     }
-    LOG("\n");
 
     for (int seq = 0; seq < n_seq_batch; seq++) {
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);

flake.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default.

ggml/src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -364,7 +364,7 @@ if (GGML_CUDA)
     if (GGML_MUSA)
         set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
         foreach(SOURCE ${GGML_SOURCES_CUDA})
-            set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
+            set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22")
         endforeach()
     endif()

ggml/src/ggml-aarch64.c

Lines changed: 503 additions & 34 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda.cu

Lines changed: 33 additions & 10 deletions
@@ -34,6 +34,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/rwkv-wkv.cuh"
 
 #include <algorithm>
 #include <array>
@@ -135,7 +136,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return res;
 #else
 
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS)
     cudaError_t err;
     if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
     {
@@ -148,7 +149,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return err;
 #else
     return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS)
 
 #endif
 }
@@ -2243,6 +2244,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_HARDSWISH:
             ggml_cuda_op_hardswish(ctx, dst);
             break;
+        case GGML_UNARY_OP_EXP:
+            ggml_cuda_op_exp(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2345,6 +2349,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
            ggml_cuda_cross_entropy_loss(ctx, dst);
            break;
+       case GGML_OP_RWKV_WKV:
+           ggml_cuda_op_rwkv_wkv(ctx, dst);
+           break;
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
            ggml_cuda_cross_entropy_loss_back(ctx, dst);
            break;
@@ -2797,6 +2804,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
+                case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -2813,6 +2821,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
                 return false;
             }
+#ifdef GGML_USE_MUSA
+            if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                return false;
+            }
+#endif // GGML_USE_MUSA
             switch (a->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
@@ -2836,6 +2850,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 case GGML_TYPE_IQ3_XXS:
                 case GGML_TYPE_IQ4_NL:
                 case GGML_TYPE_IQ4_XS:
+#ifdef GGML_USE_MUSA
+                    if (a->type == GGML_TYPE_Q3_K) {
+                        return false;
+                    }
+#endif // GGML_USE_MUSA
                     return true;
                 default:
                     return false;
@@ -2958,20 +2977,24 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_RWKV_WKV:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT:
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-            return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
-#else
+        case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
             if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+            return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
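
For readers skimming the reworked GGML_OP_FLASH_ATTN_EXT branch above, its decision order can be read as the following standalone predicate. This is only an illustrative sketch: the helper name and parameters are hypothetical, the real logic stays inline in ggml_backend_cuda_supports_op, and FLASH_ATTN_AVAILABLE, CC_VOLTA, and CC_OFFSET_AMD come from ggml-cuda/common.cuh.

// Sketch only: head_dim is op->src[0]->ne[0]; k_f16/v_f16 mean K and V are GGML_TYPE_F16.
static bool flash_attn_ext_supported_sketch(int64_t head_dim, bool k_f16, bool v_f16, int cc) {
#ifndef FLASH_ATTN_AVAILABLE              // flash attention compiled out, e.g. on MUSA arches <= CC_QY1
    return false;
#endif
    if (head_dim ==  64 && k_f16)          return true;
    if (head_dim == 128)                   return true;
    if (head_dim == 256 && k_f16 && v_f16) return true;
    // Fallback: Volta or newer NVIDIA device (cc below the AMD offset) with F16 K and V.
    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && k_f16 && v_f16;
}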

ggml/src/ggml-cuda/common.cuh

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,8 @@
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+#define CC_QY1   210
+#define CC_QY2   220
 
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -134,6 +136,10 @@ typedef float2 dfloat2;
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
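
The two new constants line up with the MUSA targets added to the build flags (mp_21 corresponding to CC_QY1 = 210, mp_22 to CC_QY2 = 220): when device code is compiled for a MUSA arch at or below CC_QY1, FLASH_ATTN_AVAILABLE stays undefined. A minimal sketch of how a kernel can consume the macro, mirroring the guard this commit adds to flash_attn_tile_ext_f32 below; the kernel name and signature here are hypothetical.

#include "common.cuh"                   // provides FLASH_ATTN_AVAILABLE and NO_DEVICE_CODE

static __global__ void fattn_example_kernel(const float * q, const float * k, const float * v, float * dst) {
#ifndef FLASH_ATTN_AVAILABLE
    NO_DEVICE_CODE;                     // bail out: this arch was compiled without flash-attention support
    return;
#endif // FLASH_ATTN_AVAILABLE
    // ... actual flash-attention work would go here ...
}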

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 5 additions & 1 deletion
@@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
     }
 
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
     const int ip  = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.

ggml/src/ggml-cuda/fattn.cu

Lines changed: 1 addition & 1 deletion
@@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8) {
+        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
