from vllm_ascend.ascend_forward_context import (FusedMoEState,
                                                _get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
-                                       AscendUnquantizedFusedMoEMethod,
-                                       unified_apply_mlp)
+                                       AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.layers.experts_selector import select_experts
+from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch

adapt_patch(True)
@@ -129,36 +129,38 @@ def capture_register(dispatcher_instance):
                                      with_quant=False)

    with patch('torch.distributed.get_rank', return_value=0), \
-        patch('torch.distributed.get_world_size', return_value=4), \
-        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('torch.distributed.all_gather'), \
-        patch('torch.distributed.all_to_all_single'), \
-        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
-        patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
-        patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
-              return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_ascend_config',
-              return_value=MagicMock(
-                  torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
-                  expert_map_path=None
-              )), \
-        patch('vllm_ascend.ops.fused_moe.determine_expert_map',
-              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
-        patch('vllm_ascend.ops.fused_moe.get_forward_context',
-              return_value=mock_forward_context_obj), \
+        patch('torch.distributed.get_world_size', return_value=4), \
+        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('torch.distributed.all_gather'), \
+        patch('torch.distributed.all_to_all_single'), \
+        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
+        patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
+        patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
+              return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_ascend_config',
+              return_value=MagicMock(
+                  torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
+                  expert_map_path=None
+              )), \
+        patch('vllm_ascend.ops.fused_moe.determine_expert_map',
+              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
+        patch('vllm_ascend.ops.fused_moe.get_forward_context',
+              return_value=mock_forward_context_obj), \
        patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
-              return_value=MagicMock(
-                  parallel_config=MagicMock(tensor_parallel_size=2),
-                  scheduler_config=MagicMock(max_num_seqs=4),
-                  model_config=MagicMock(max_model_len=2048)
-              )), \
+              return_value=MagicMock(
+                  parallel_config=MagicMock(tensor_parallel_size=2),
+                  scheduler_config=MagicMock(max_num_seqs=4),
+                  model_config=MagicMock(max_model_len=2048)
+              )), \
        patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
-        patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers):
+        patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
+        patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
+              return_value=mock_forward_context_obj):

        yield {
            'mock_forward_context_obj': mock_forward_context_obj,
@@ -441,12 +443,11 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,

        assert result.shape == expected_shape

-    @pytest.mark.parametrize("others_param",
-                             [[16, False], [1, True], [1, False], [4, False]])
+    @pytest.mark.parametrize("others_param", [16, 1, 4])
    def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                   mock_moe_env, others_param):

-        ep_size, alltoall_buffer = others_param
+        ep_size = others_param
        is_prefill = False

        if ep_size == 1:
@@ -464,9 +465,7 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                           with_quant=False,
                                           token_dispatcher=selected_token_dispatcher)

-        with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
-                   alltoall_buffer), \
-             patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
+        with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
             patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):

            expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
@@ -475,8 +474,6 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
        if ep_size == 1:
            x = x.view(-1, 2)
        router_logits = torch.randn(8, 8)
-        if alltoall_buffer:
-            moe_method.max_model_len = 1
        layer = MagicMock()

        local_num_experts = 2
@@ -529,26 +526,21 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,

class TestUnifiedApplyMLP(TestBase):

-    @patch('vllm_ascend.ops.fused_moe.get_forward_context')
-    @patch('vllm_ascend.ops.fused_moe.get_mc2_group')
-    @patch('vllm_ascend.ops.fused_moe.is_310p')
+    @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_dynamic_quant')
    @patch('torch_npu.npu_dequant_swiglu_quant')
    def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                     mock_npu_dynamic_quant,
                                                     mock_npu_grouped_matmul,
                                                     mock_is_310p,
-                                                     mock_get_mc2_group,
                                                     mock_get_forward_context):

        mock_forward_context = MagicMock()
        mock_forward_context.fused_moe_state = FusedMoEState.MC2
        mock_get_forward_context.return_value = mock_forward_context

-        mock_mc2_group = MagicMock()
-        mock_get_mc2_group.return_value = mock_mc2_group
-
        mock_is_310p.return_value = False

        mock_npu_dynamic_quant.return_value = (torch.randint(-128,
@@ -601,7 +593,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,

        self.assertEqual(result.dtype, torch.bfloat16)

-    @patch('vllm_ascend.ops.fused_moe.is_310p')
+    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
@@ -643,7 +635,7 @@ def test_unified_apply_mlp_without_quantization(self,
        self.assertEqual(result.shape, hidden_states.shape)
        self.assertEqual(result.dtype, torch.float16)

-    @patch('vllm_ascend.ops.fused_moe.get_forward_context')
+    @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
@@ -703,7 +695,7 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
        self.assertEqual(result.shape, hidden_states.shape)
        self.assertEqual(result.dtype, torch.bfloat16)

-    @patch('vllm_ascend.ops.fused_moe.is_310p')
+    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
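
Note: the relocation above means tests now resolve unified_apply_mlp and its helpers through vllm_ascend.ops.layers.moe_mlp rather than vllm_ascend.ops.fused_moe. A minimal, hypothetical sketch of the new patch targets follows; only the module paths and names that appear in this diff are assumed to exist, the function name below is illustrative, and the actual unified_apply_mlp call is elided.

# Hypothetical sketch, not part of this diff: patch the helpers at their
# new module path (vllm_ascend.ops.layers.moe_mlp).
from unittest.mock import MagicMock, patch

from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp  # new import path


def sketch_patch_targets():
    mock_ctx = MagicMock()
    mock_ctx.fused_moe_state = FusedMoEState.MC2
    # Per the diff, get_forward_context and is_310p are now patched in
    # moe_mlp, not fused_moe, since that is where unified_apply_mlp lives.
    with patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
               return_value=mock_ctx), \
         patch('vllm_ascend.ops.layers.moe_mlp.is_310p', return_value=False):
        ...  # call unified_apply_mlp here (arguments elided)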