@@ -16,7 +16,7 @@
 from pathlib import Path
 from typing import IO, Any, Iterable, List, Optional, Tuple
 import numpy as np
-import math
+import math, gc

 import torch
 from torch import nn
@@ -184,6 +184,7 @@ class ModelType(Enum):
     TeleChat2 = 0x1e00

     HunYuanDense = 0x1f00
+    HunYuanMoEV1 = 0x1f01

     MoonLight = 0x2000

@@ -605,6 +606,8 @@ def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_
         dump_tensor(f, name, tensor, tensor_ggml_type)
         tensor_info.append((name, tensor.shape, tensor_ggml_type.name))

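+    # reclaim memory from tensors released during the dump before building the summary table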
+    gc.collect()
+
     print(tabulate(tensor_info, headers=["name", "shape", "dtype"]))

     if len(tensor_info) != len(weight_names):
@@ -6521,6 +6524,92 @@ def get_weight_names(config):

     return weight_names

+class HunYuanMoEV1Converter(BaseConverter):
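+    # converts HunYuanMoEV1ForCausalLM checkpoints: per-layer routed experts plus one shared expert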
+    MODEL_TYPE = ModelType.HunYuanMoEV1
+
+    @classmethod
+    def state_dict_pp(cls, config, state_dict):
+        new_dict = {}
+
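+        # remap HF tensor names to the loader's layout: router 'gate.wg' -> 'gate', 'shared_mlp' -> 'shared_expert'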
+        for name in state_dict:
+            tensor: torch.Tensor = state_dict[name]
+            new_name = name
+            new_name = new_name.replace('.mlp.gate.wg.', '.mlp.gate.')
+            new_name = new_name.replace('.shared_mlp.', '.shared_expert.')
+
+            new_dict[new_name] = tensor
+
+        return new_dict
+
+    @staticmethod
+    def dump_config(f, config, ggml_type):
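+        # supported configuration subset; anything outside it is rejected up front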
+        assert config.tie_word_embeddings, "tie_word_embeddings must be True"
+        assert not config.attention_bias, "attention_bias must be False"
+        assert not config.mlp_bias, "mlp_bias must be False"
+        assert not config.use_cla, "use_cla must be False"
+        assert not config.use_mla, "use_mla must be False"
+        assert config.rope_scaling['type'] == 'dynamic', "rope_scaling['type'] must be 'dynamic'"
+        assert config.use_qk_norm, "use_qk_norm must be True"
+        assert config.rope_scaling['alpha'] > 0, "rope_scaling['alpha'] must be > 0"
+        assert config.moe_layer_num_skipped == 0
+        assert config.use_mixed_mlp_moe
+        assert len(set(config.moe_intermediate_size)) == 1
+        assert len(set(config.moe_topk)) == 1
+        assert len(set(config.num_shared_expert)) == 1
+        assert config.attention_head_dim == config.hidden_size / config.num_attention_heads
+
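+        # fold the dynamic NTK 'alpha' factor into the RoPE base: theta' = theta * alpha ** (d / (d - 2))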
+        head_dim = config.attention_head_dim
+        config.rope_theta = config.rope_theta * config.rope_scaling['alpha'] ** (head_dim / (head_dim - 2))
+
+        dump_llama_like_config(f, config, ggml_type)
+
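+        # extra fields beyond the llama-like header: int32 values first, then float32, all little-endian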
+        config_values = [
+            config.num_key_value_heads,
+            config.num_experts,
+
+            list(set(config.moe_intermediate_size))[0],
+            list(set(config.moe_topk))[0],
+            list(set(config.num_shared_expert))[0],
+        ]
+        f.write(struct.pack("<" + "i" * len(config_values), *config_values))
+
+        config_values = [
+            config.rope_theta,
+        ]
+        f.write(struct.pack("<" + "f" * len(config_values), *config_values))
+
+    @staticmethod
+    def get_weight_names(config):
+        weight_names = ["model.embed_tokens.weight"]
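+        # names listed here drive dump_state_dict; experts are laid out expert-major within each layer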
+        for i in range(config.num_hidden_layers):
+            for j in range(config.num_experts):
+                weight_names += [
+                    f"model.layers.{i}.mlp.experts.{j}.down_proj.weight",
+                    f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight",
+                    f"model.layers.{i}.mlp.experts.{j}.up_proj.weight",
+                ]
+
+            weight_names += [
+                f"model.layers.{i}.mlp.gate.weight",
+                f"model.layers.{i}.mlp.shared_expert.down_proj.weight",
+                f"model.layers.{i}.mlp.shared_expert.gate_proj.weight",
+                f"model.layers.{i}.mlp.shared_expert.up_proj.weight",
+                f"model.layers.{i}.input_layernorm.weight",
+                f"model.layers.{i}.post_attention_layernorm.weight",
+                f"model.layers.{i}.self_attn.k_proj.weight",
+                f"model.layers.{i}.self_attn.o_proj.weight",
+                f"model.layers.{i}.self_attn.q_proj.weight",
+                f"model.layers.{i}.self_attn.v_proj.weight",
+                f"model.layers.{i}.self_attn.key_layernorm.weight",
+                f"model.layers.{i}.self_attn.query_layernorm.weight",
+            ]
+
+        weight_names += [
+            "model.norm.weight",
+        ]
+
+        return weight_names
+
 class SolarConverter(BaseConverter):
     MODEL_TYPE = ModelType.SolarPro

@@ -7380,6 +7469,8 @@ def main():
                 (isinstance(config.num_experts, list) and max(config.num_experts) > 1)):
             raise Exception('HunYuanForCausalLM: only dense model is supported')
         HunYuanDenseConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
+    elif arch == 'HunYuanMoEV1ForCausalLM':
+        HunYuanMoEV1Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'InstellaForCausalLM':
         InstellaConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'DeciLMForCausalLM':