Commit 50e24bb

support Hunyuan MoE

1 parent 746f67a · commit 50e24bb

12 files changed: +500 −248 lines

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ set(core_files src/backend.cpp
                src/unicode-data.cpp
                src/vision_process.cpp
                src/audio_process.cpp
-               models/qwen.cpp)
+               models/qwen.cpp
+               models/hunyuan.cpp)
 
 add_library(libchatllm SHARED EXCLUDE_FROM_ALL src/main.cpp ${core_files})
 target_link_libraries(libchatllm PRIVATE ggml)

README.md

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [g
 
 **What's New:**
 
+* 2025-06-30: Hunyuan-A13B
 * 2025-06-21: [I can hear](./docs/multimodal.md): Qwen2-Audio
 * 2025-06-10: SmolVLM2
 * 2025-06-07: MiniCPM4

convert.py

Lines changed: 92 additions & 1 deletion
@@ -16,7 +16,7 @@
 from pathlib import Path
 from typing import IO, Any, Iterable, List, Optional, Tuple
 import numpy as np
-import math
+import math, gc
 
 import torch
 from torch import nn
@@ -184,6 +184,7 @@ class ModelType(Enum):
     TeleChat2 = 0x1e00
 
     HunYuanDense = 0x1f00
+    HunYuanMoEV1 = 0x1f01
 
     MoonLight = 0x2000
 
@@ -605,6 +606,8 @@ def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_
         dump_tensor(f, name, tensor, tensor_ggml_type)
         tensor_info.append((name, tensor.shape, tensor_ggml_type.name))
 
+        gc.collect()
+
     print(tabulate(tensor_info, headers=["name", "shape", "dtype"]))
 
     if len(tensor_info) != len(weight_names):
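A note on the new `gc.collect()`: `dump_state_dict` streams tensors out one at a time, and forcing a collection between tensors keeps garbage from piling up while a multi-gigabyte checkpoint is converted. A minimal sketch of the pattern, with a hypothetical `write_tensor` helper standing in for the real dump logic:

```python
import gc

def dump_all(f, state_dict):
    # Stream tensors out one at a time, dropping each reference
    # before moving on, so peak memory stays near one tensor's size.
    for name in list(state_dict.keys()):
        tensor = state_dict.pop(name)  # drop the dict's reference
        write_tensor(f, name, tensor)  # hypothetical writer
        del tensor                     # drop the local reference
        gc.collect()                   # sweep any cyclic garbage, as the commit does
```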
@@ -6521,6 +6524,92 @@ def get_weight_names(config):
 
         return weight_names
 
+class HunYuanMoEV1Converter(BaseConverter):
+    MODEL_TYPE = ModelType.HunYuanMoEV1
+
+    @classmethod
+    def state_dict_pp(cls, config, state_dict):
+        new_dict = {}
+
+        for name in state_dict:
+            tensor: torch.Tensor = state_dict[name]
+            new_name = name
+            new_name = new_name.replace('.mlp.gate.wg.', '.mlp.gate.')
+            new_name = new_name.replace('.shared_mlp.', '.shared_expert.')
+
+            new_dict[new_name] = tensor
+
+        return new_dict
+
+    @staticmethod
+    def dump_config(f, config, ggml_type):
+        assert config.tie_word_embeddings, "tie_word_embeddings must be True"
+        assert config.attention_bias == False, "attention_bias must be False"
+        assert config.mlp_bias == False, "mlp_bias must be False"
+        assert not config.use_cla, "use_cla must be False"
+        assert not config.use_mla, "use_mla must be False"
+        assert config.rope_scaling['type'] == 'dynamic', "rope_scaling['type'] must be 'dynamic'"
+        assert config.use_qk_norm, "use_qk_norm must be True"
+        assert config.rope_scaling['alpha'] > 0, "rope_scaling['alpha'] must be > 0"
+        assert config.moe_layer_num_skipped == 0
+        assert config.use_mixed_mlp_moe
+        assert len(set(config.moe_intermediate_size)) == 1
+        assert len(set(config.moe_topk)) == 1
+        assert len(set(config.num_shared_expert)) == 1
+        assert config.attention_head_dim == config.hidden_size / config.num_attention_heads
+
+        head_dim = config.attention_head_dim
+        config.rope_theta = config.rope_theta * config.rope_scaling['alpha'] ** (head_dim / (head_dim - 2))
+
+        dump_llama_like_config(f, config, ggml_type)
+
+        config_values = [
+            config.num_key_value_heads,
+            config.num_experts,
+
+            list(set(config.moe_intermediate_size))[0],
+            list(set(config.moe_topk))[0],
+            list(set(config.num_shared_expert))[0],
+        ]
+        f.write(struct.pack("<" + "i" * len(config_values), *config_values))
+
+        config_values = [
+            config.rope_theta,
+        ]
+        f.write(struct.pack("<" + "f" * len(config_values), *config_values))
+
+    @staticmethod
+    def get_weight_names(config):
+        weight_names = ["model.embed_tokens.weight"]
+        for i in range(config.num_hidden_layers):
+            for j in range(config.num_experts):
+                weight_names += [
+                    f"model.layers.{i}.mlp.experts.{j}.down_proj.weight",
+                    f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight",
+                    f"model.layers.{i}.mlp.experts.{j}.up_proj.weight",
+                ]
+
+            weight_names += [
+                f"model.layers.{i}.mlp.gate.weight",
+                f"model.layers.{i}.mlp.shared_expert.down_proj.weight",
+                f"model.layers.{i}.mlp.shared_expert.gate_proj.weight",
+                f"model.layers.{i}.mlp.shared_expert.up_proj.weight",
+                f"model.layers.{i}.input_layernorm.weight",
+                f"model.layers.{i}.post_attention_layernorm.weight",
+                f"model.layers.{i}.self_attn.k_proj.weight",
+                f"model.layers.{i}.self_attn.o_proj.weight",
+                f"model.layers.{i}.self_attn.q_proj.weight",
+                f"model.layers.{i}.self_attn.v_proj.weight",
+                f"model.layers.{i}.self_attn.key_layernorm.weight",
+                f"model.layers.{i}.self_attn.query_layernorm.weight",
+            ]
+
+        weight_names += [
+            "model.norm.weight",
+        ]
+
+        return weight_names
+
 class SolarConverter(BaseConverter):
     MODEL_TYPE = ModelType.SolarPro
 
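The `rope_theta` rewrite in `dump_config` is the NTK-aware RoPE base adjustment, base′ = base · alpha^(d/(d−2)) with d the head dimension: the dynamic scaling factor is folded into the base once at conversion time so the runtime can apply plain RoPE. A quick sanity check with illustrative numbers (not taken from any particular checkpoint):

```python
# NTK-aware base adjustment, as in HunYuanMoEV1Converter.dump_config.
# The concrete values below are hypothetical, for illustration only.
rope_theta = 10000.0  # original RoPE base
alpha = 1000.0        # rope_scaling['alpha']
head_dim = 128        # attention_head_dim

scaled = rope_theta * alpha ** (head_dim / (head_dim - 2))
print(f"{scaled:.3e}")  # ~1.116e+07 is the base stored in the converted file
```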
@@ -7380,6 +7469,8 @@ def main():
             (isinstance(config.num_experts, list) and max(config.num_experts) > 1)):
             raise Exception('HunYuanForCausalLM: only dense model is supported')
         HunYuanDenseConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
+    elif arch == 'HunYuanMoEV1ForCausalLM':
+        HunYuanMoEV1Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'InstellaForCausalLM':
         InstellaConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'DeciLMForCausalLM':
docs/models.md

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@
 
 * HunYuan (`HunYuanForCausalLM`)
     * [x] Dense: [Instruct-7B](https://huggingface.co/tencent/Hunyuan-7B-Instruct)
+    * [x] MoE: [A13B-Instruct](https://huggingface.co/tencent/Hunyuan-A13B-Instruct/tree/202c9758065873e0ac7c80211e6275593f165442)
 
 * Instella (`InstellaForCausalLM`)
     * [x] [Instruct-3B](https://huggingface.co/amd/Instella-3B-Instruct)

models/grok.cpp

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ const int NUM_EXPERTS = 8;
 const int EXPERTS_PER_TOK = 2;
 
 // make it easy to test with different number of experts.
-#define EFFECTIVE_EXPERTS_PER_TOK EXPERTS_PER_TOK
+const int EFFECTIVE_EXPERTS_PER_TOK = EXPERTS_PER_TOK;
 
 class GrokBaseAttention : public BaseAttention
 {
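Replacing the macro with a typed `const int` is behavior-preserving: the constant gains a real type and scope and stays visible to the debugger, while the stated purpose, testing with a different number of experts, still only requires editing one initializer.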
