
Commit 06dfde3

llama : add basic support for offloading moe with CUDA
1 parent 2cbcba8 commit 06dfde3

3 files changed: +61 -19 lines

ggml-cuda.cu (24 additions & 9 deletions)
@@ -8242,15 +8242,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     // TODO: mmq/mmv support
 #endif
 
-    const struct ggml_tensor * ids = src0;
-    const int32_t id = dst->op_params[0];
-    const int32_t n_as = dst->op_params[1];
+    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
 
-    const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+    const struct ggml_tensor * ids = src0;
+    const int32_t id = ((int32_t *) dst->op_params)[0];
+    const int32_t n_as = ((int32_t *) dst->op_params)[1];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
-    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+    if (ids->backend == GGML_BACKEND_GPU) {
+        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    } else {
+        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+    }
 
     const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
     const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
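Note: the net effect of this hunk is that the expert-selection ids are always read back into host memory — a device-to-host copy when the ids tensor lives on the GPU, a plain memcpy when it is still on the CPU. A minimal standalone sketch of that pattern (the Tensor/Backend types and ids_to_host are hypothetical stand-ins, not ggml's; only the CUDA runtime calls are real API):

    #include <cstdint>
    #include <cstring>
    #include <vector>
    #include <cuda_runtime.h>

    enum class Backend { CPU, GPU };

    struct Tensor {
        Backend backend;
        void *  data;    // host pointer when backend == CPU, device pointer when GPU
        size_t  nbytes;
    };

    // copy ids into host memory regardless of which backend owns them
    static std::vector<char> ids_to_host(const Tensor & ids, cudaStream_t stream) {
        std::vector<char> host(ids.nbytes);
        if (ids.backend == Backend::GPU) {
            // async copy, then synchronize: the caller reads the ids on the CPU next
            cudaMemcpyAsync(host.data(), ids.data, ids.nbytes, cudaMemcpyDeviceToHost, stream);
            cudaStreamSynchronize(stream);
        } else {
            memcpy(host.data(), ids.data, ids.nbytes);
        }
        return host;
    }

    int main() {
        int32_t ids_data[2] = { 1, 3 };
        Tensor  ids = { Backend::CPU, ids_data, sizeof(ids_data) };
        std::vector<char> host = ids_to_host(ids, 0);  // CPU path, runs without a GPU
        return host.size() == sizeof(ids_data) ? 0 : 1;
    }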
@@ -8264,7 +8270,9 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     src1_row.ne[1] = 1;
     dst_row.ne[1] = 1;
 
-    src1_row.extra = &src1_row_extra;
+    if (src1->backend == GGML_BACKEND_GPU) {
+        src1_row.extra = &src1_row_extra;
+    }
     dst_row.extra = &dst_row_extra;
 
     for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -8278,7 +8286,12 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
         const struct ggml_tensor * src0_row = dst->src[row_id + 2];
 
-        src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        if (src1->backend == GGML_BACKEND_GPU) {
+            src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        } else {
+            src1_row.data = (char *) src1->data + i01*src1->nb[1];
+        }
+
         dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
 
         ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
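The loop addresses row i01 of src1 and dst purely by byte strides — data + i01*nb[1] — whether data is a device pointer (via the extras) or a host pointer. A plain C++ sketch of that addressing on a hypothetical buffer:

    #include <cstddef>
    #include <cstdio>

    int main() {
        const int ne0 = 4, ne1 = 3;               // 4 floats per row, 3 rows
        float buf[ne1][ne0] = { {0,1,2,3}, {4,5,6,7}, {8,9,10,11} };

        const char * base = (const char *) buf;
        const size_t nb1  = ne0 * sizeof(float);  // byte stride between consecutive rows

        for (int i01 = 0; i01 < ne1; i01++) {
            const float * row = (const float *)(base + i01*nb1);
            printf("row %d starts with %.0f\n", i01, row[0]);  // prints 0, 4, 8
        }
        return 0;
    }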
@@ -8694,7 +8707,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             func = ggml_cuda_repeat;
             break;
         case GGML_OP_GET_ROWS:
-            func = ggml_cuda_get_rows;
+            if (ggml_is_contiguous(tensor->src[1])) {
+                func = ggml_cuda_get_rows;
+            }
            break;
         case GGML_OP_DUP:
             func = ggml_cuda_dup;
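The GET_ROWS change gates the CUDA kernel on the index tensor being contiguous; non-contiguous indices (such as the views produced by the MoE routing below) are not handled by the CUDA path in this commit. ggml's contiguity test amounts to checking that the byte strides are dense; a simplified sketch of that check (4-D shape, ignoring the block sizes of quantized types):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    struct Shape {
        int64_t ne[4];  // elements per dimension, ggml-style
        size_t  nb[4];  // byte stride per dimension
    };

    static bool is_contiguous(const Shape & t, size_t type_size) {
        if (t.nb[0] != type_size) {
            return false;  // elements in dim 0 must be densely packed
        }
        for (int i = 1; i < 4; i++) {
            if (t.nb[i] != t.nb[i - 1] * (size_t) t.ne[i - 1]) {
                return false;  // each stride must cover exactly the previous dimension
            }
        }
        return true;
    }

    int main() {
        Shape dense = { {8, 4, 1, 1}, {4, 32, 128, 128} };  // packed f32 tensor
        printf("contiguous: %d\n", is_contiguous(dense, sizeof(float)));  // prints 1
        return 0;
    }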

ggml.c (0 additions & 1 deletion)
@@ -4105,7 +4105,6 @@ struct ggml_tensor * ggml_mul_mat_id(
     result->src[0] = ids;
     result->src[1] = b;
 
-    // TODO: n_as is the selected experts, but it should be the total number of experts
     for (int i = 0; i < n_as; i++) {
         struct ggml_tensor * a = as[i];
         GGML_ASSERT(ggml_are_same_shape(as[0], a));

llama.cpp (37 additions & 9 deletions)
@@ -4247,16 +4247,25 @@ struct llm_build_context {
             const int n_experts_per_tok = 2;
 
             ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+            cb(logits, "ffn_moe_logits", il);
+
             ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+            cb(probs, "ffn_moe_probs", il);
 
             // select experts
             ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
-            ggml_tensor * weights =
-                ggml_reshape_2d(ctx0,
-                    ggml_get_rows(ctx0,
-                        ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts),
+            ggml_tensor * weights = ggml_get_rows(ctx0,
+                    ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
+            cb(weights, "ffn_moe_weights", il);
+
+            weights = ggml_reshape_2d(ctx0, weights,
                     n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
-            weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok]
+
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+            cb(weights_sum, "ffn_moe_weights_sum", il);
+
+            weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+            cb(weights, "ffn_moe_weights_norm", il);
 
             // compute expert outputs
             ggml_tensor * moe_out;
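Written out as scalar code, the routing this graph encodes is: softmax over the gate logits, keep the top n_experts_per_tok probabilities, then renormalize them so the kept weights sum to 1 (the ggml_div by ggml_sum_rows). An illustrative standalone C++ version, not the ggml implementation:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> logits = { 1.0f, 3.0f, 0.5f, 2.0f };  // one token, 4 experts
        const int n_experts_per_tok = 2;

        // softmax -> probs
        const float mx = *std::max_element(logits.begin(), logits.end());
        std::vector<float> probs(logits.size());
        float sum = 0.0f;
        for (size_t e = 0; e < logits.size(); e++) {
            probs[e] = std::exp(logits[e] - mx);
            sum += probs[e];
        }
        for (float & p : probs) {
            p /= sum;
        }

        // top-k expert indices by probability (what ggml_top_k selects)
        std::vector<int> idx(probs.size());
        for (size_t e = 0; e < idx.size(); e++) idx[e] = (int) e;
        std::partial_sort(idx.begin(), idx.begin() + n_experts_per_tok, idx.end(),
                          [&](int a, int b) { return probs[a] > probs[b]; });

        // renormalize the selected weights so they sum to 1
        float wsum = 0.0f;
        for (int k = 0; k < n_experts_per_tok; k++) wsum += probs[idx[k]];
        for (int k = 0; k < n_experts_per_tok; k++) {
            printf("expert %d weight %.3f\n", idx[k], probs[idx[k]] / wsum);
        }
        return 0;
    }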
@@ -4269,19 +4278,30 @@ struct llm_build_context {
                 ggml_tensor ** ffn_gate_exp = (ggml_tensor **) model.layers[il].ffn_gate_exp;
                 ggml_tensor ** ffn_down_exp = (ggml_tensor **) model.layers[il].ffn_down_exp;
 
-                cur_expert = ggml_mul(ctx0,
-                        ggml_mul_mat_id(ctx0, ffn_up_exp, n_experts, selected_experts, i, cur),
-                        ggml_silu(ctx0,
-                            ggml_mul_mat_id(ctx0, ffn_gate_exp, n_experts, selected_experts, i, cur))); // [n_tokens, n_embd]
+                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, ffn_up_exp, n_experts, selected_experts, i, cur);
+                cb(cur_up, "ffn_up", il);
+
+                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, ffn_gate_exp, n_experts, selected_experts, i, cur);
+                cb(cur_gate, "ffn_gate", il);
+
+                cur_gate = ggml_silu(ctx0, cur_gate);
+                cb(cur_gate, "ffn_silu", il);
+
+                cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
+                cb(cur_expert, "ffn_gate_par", il);
 
                 cur_expert = ggml_mul_mat_id(ctx0, ffn_down_exp, n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                cb(cur_expert, "ffn_down", il);
+
                 cur_expert = ggml_mul(ctx0, cur_expert,
                         ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+                cb(cur_expert, "ffn_moe_weighted", il);
 
                 if (i == 0) {
                     moe_out = cur_expert;
                 } else {
                     moe_out = ggml_add(ctx0, moe_out, cur_expert);
+                    cb(moe_out, "ffn_moe_out", il);
                 }
             }
 
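Each loop iteration builds one selected expert's FFN — SwiGLU gating (silu(gate(x)) * up(x), the ffn_gate_par tensor), the down projection, then scaling by the token's routing weight before accumulating into moe_out. A scalar stand-in for a single expert's contribution (illustrative only; 1-D "matrices"):

    #include <cmath>
    #include <cstdio>

    static float silu(float x) { return x / (1.0f + std::exp(-x)); }

    int main() {
        const float x = 0.7f;                                    // one input feature
        const float w_up = 1.5f, w_gate = -0.4f, w_down = 0.9f;  // per-expert weights
        const float routing_weight = 0.6f;                       // from the normalized top-k

        const float up   = w_up * x;                             // ffn_up
        const float gate = silu(w_gate * x);                     // ffn_gate -> ffn_silu
        const float ffn  = w_down * (gate * up);                 // ffn_gate_par -> ffn_down
        printf("expert contribution: %f\n", routing_weight * ffn);  // ffn_moe_weighted
        return 0;
    }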
@@ -5540,6 +5560,14 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "ffn_relu",                   OFFLOAD_FUNC },
     { "ffn_sqr(relu)",              OFFLOAD_FUNC },
 
+    { "ffn_moe_logits",             OFFLOAD_FUNC },
+    { "ffn_moe_probs",              OFFLOAD_FUNC },
+    { "ffn_moe_weights",            OFFLOAD_FUNC_NOP },
+    { "ffn_moe_weights_sum",        OFFLOAD_FUNC },
+    { "ffn_moe_weights_norm",       OFFLOAD_FUNC },
+    { "ffn_moe_weighted",           OFFLOAD_FUNC },
+    { "ffn_moe_out",                OFFLOAD_FUNC },
+
     { "l_out",                      OFFLOAD_FUNC },
 
     { "result_norm",                OFFLOAD_FUNC_EMB },
