
Commit 3a5fc5e

[Refactor][MoE] remove redundant code after refactoring fused_moe (#2612)
### What this PR does / why we need it?
There is a lot of redundant MoE-related code here, and the structure is not very clear. This PR does the following:
- moves the relatively independent apply_mlp code into a separate file;
- removes the alltoall_buffer and alltoall_seq environment variables;
- removes the code related to alltoall_buffer and alltoall_seq, keeping only a single TokenDispatcher subclass.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
e2e & ut

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@4071c76

---------

Signed-off-by: Pr0Wh1teGivee <[email protected]>
Signed-off-by: weijinqian_v1 <[email protected]>
Co-authored-by: weijinqian0 <[email protected]>
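For code that imported the relocated helpers, the import-path changes can be read off the updated unit tests below. A minimal sketch, assuming only what those tests show (module paths and exported names are taken from the diffs in this commit, nothing else):

```python
# Minimal sketch of the import changes implied by the updated tests in this
# commit; module paths and names are copied from the diffs below.

# Before this commit the tests imported the MLP helper from the fused_moe module:
#   from vllm_ascend.ops.fused_moe import unified_apply_mlp
# After the refactor it lives in a dedicated module:
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp

# The alltoall_seq dispatcher (MoEAlltoAllSeqOverLapDispatcher) is removed; the
# remaining dispatchers are still exposed by the token_dispatcher module:
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
    TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
    TokenDispatcherWithMC2, get_token_dispatcher, setup_token_dispatchers)
```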
1 parent 20ae712 commit 3a5fc5e

13 files changed: +417 additions, -1237 deletions


.github/workflows/vllm_ascend_test.yaml

Lines changed: 0 additions & 1 deletion
@@ -279,7 +279,6 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
-          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 7 additions & 30 deletions
@@ -108,14 +108,13 @@ def test_models_distributed_pangu():
     ]
     max_tokens = 5

-    with VllmRunner(
-            snapshot_download("vllm-ascend/pangu-pro-moe-pruing"),
-            max_model_len=8192,
-            enforce_eager=True,
-            dtype="auto",
-            tensor_parallel_size=2,
-            distributed_executor_backend="mp",
-    ) as vllm_model:
+    with VllmRunner(snapshot_download("vllm-ascend/pangu-pro-moe-pruing"),
+                    max_model_len=8192,
+                    enforce_eager=True,
+                    dtype="auto",
+                    tensor_parallel_size=2,
+                    distributed_executor_backend="mp",
+                    enable_expert_parallel=True) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)

@@ -141,28 +140,6 @@ def test_models_distributed_topk() -> None:
     vllm_model.generate(example_prompts, sampling_params)


-@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"})
-def test_models_distributed_alltoallv() -> None:
-    example_prompts = [
-        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
-        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
-        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
-    ]
-    dtype = "half"
-    sampling_params = SamplingParams(max_tokens=5,
-                                     temperature=0.0,
-                                     top_k=50,
-                                     top_p=0.9)
-
-    with VllmRunner(
-            "deepseek-ai/DeepSeek-V2-Lite",
-            dtype=dtype,
-            tensor_parallel_size=2,
-            distributed_executor_backend="mp",
-    ) as vllm_model:
-        vllm_model.generate(example_prompts, sampling_params)
-
-
 def test_models_distributed_Qwen3_W8A8():
     example_prompts = [
         "Hello, my name is",
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+from unittest.mock import patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.ops.common_fused_moe import fused_experts_moge
+
+
+class TestFusedExpertsMoGE(TestBase):
+
+    def test_fused_experts_moge(self):
+        with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
+             patch('torch_npu.npu_swiglu') as mock_swiglu, \
+             patch('vllm_ascend.utils.is_310p') as mock_is_310p:
+
+            mock_is_310p.return_value = False
+
+            mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
+                torch.randn(x[0].shape[0], weight[0].shape[1])
+            ]
+
+            mock_swiglu.side_effect = lambda x: x
+
+            hidden_states = torch.randn(4, 128)
+            w1 = torch.randn(4, 256, 128)
+            w2 = torch.randn(4, 128, 128)
+            topk_weights = torch.rand(4, 1)
+            topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
+            top_k = 1
+            global_num_experts = 4
+
+            moe_parallel_config = type(
+                'MockConfig', (), {
+                    'ep_size': 1,
+                    'tp_size': 1,
+                    'dp_size': 1,
+                    'tp_rank': 0,
+                    'dp_rank': 0,
+                    'ep_rank': 0,
+                    'use_ep': True
+                })()
+
+            output = fused_experts_moge(
+                hidden_states=hidden_states,
+                w1=w1,
+                w2=w2,
+                moe_parallel_config=moe_parallel_config,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=True,
+            )
+
+            self.assertEqual(output.shape, (4, 128))

tests/ut/ops/test_fused_ops.py

Lines changed: 40 additions & 48 deletions
@@ -27,9 +27,9 @@
 from vllm_ascend.ascend_forward_context import (FusedMoEState,
                                                 _get_fused_moe_state)
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
-                                       AscendUnquantizedFusedMoEMethod,
-                                       unified_apply_mlp)
+                                       AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.ops.layers.experts_selector import select_experts
+from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
 from vllm_ascend.utils import AscendSocVersion, adapt_patch

 adapt_patch(True)
@@ -129,36 +129,38 @@ def capture_register(dispatcher_instance):
                                       with_quant=False)

     with patch('torch.distributed.get_rank', return_value=0), \
-        patch('torch.distributed.get_world_size', return_value=4), \
-        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('torch.distributed.all_gather'), \
-        patch('torch.distributed.all_to_all_single'), \
-        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
-        patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
-        patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
-              return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_ascend_config',
-              return_value=MagicMock(
-                  torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
-                  expert_map_path=None
-              )), \
-        patch('vllm_ascend.ops.fused_moe.determine_expert_map',
-              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
-        patch('vllm_ascend.ops.fused_moe.get_forward_context',
-              return_value=mock_forward_context_obj), \
+        patch('torch.distributed.get_world_size', return_value=4), \
+        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('torch.distributed.all_gather'), \
+        patch('torch.distributed.all_to_all_single'), \
+        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
+        patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
+        patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
+              return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_ascend_config',
+              return_value=MagicMock(
+                  torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
+                  expert_map_path=None
+              )), \
+        patch('vllm_ascend.ops.fused_moe.determine_expert_map',
+              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
+        patch('vllm_ascend.ops.fused_moe.get_forward_context',
+              return_value=mock_forward_context_obj), \
         patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
-              return_value=MagicMock(
-                  parallel_config=MagicMock(tensor_parallel_size=2),
-                  scheduler_config=MagicMock(max_num_seqs=4),
-                  model_config=MagicMock(max_model_len=2048)
-              )), \
+              return_value=MagicMock(
+                  parallel_config=MagicMock(tensor_parallel_size=2),
+                  scheduler_config=MagicMock(max_num_seqs=4),
+                  model_config=MagicMock(max_model_len=2048)
+              )), \
         patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
-        patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers):
+        patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
+        patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
+              return_value=mock_forward_context_obj):

         yield {
             'mock_forward_context_obj': mock_forward_context_obj,
@@ -441,12 +443,11 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,

         assert result.shape == expected_shape

-    @pytest.mark.parametrize("others_param",
-                             [[16, False], [1, True], [1, False], [4, False]])
+    @pytest.mark.parametrize("others_param", [16, 1, 4])
     def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                    mock_moe_env, others_param):

-        ep_size, alltoall_buffer = others_param
+        ep_size = others_param
         is_prefill = False

         if ep_size == 1:
@@ -464,9 +465,7 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                    with_quant=False,
                                    token_dispatcher=selected_token_dispatcher)

-        with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
-                   alltoall_buffer), \
-            patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
+        with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
             patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):

             expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
@@ -475,8 +474,6 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
             if ep_size == 1:
                 x = x.view(-1, 2)
             router_logits = torch.randn(8, 8)
-            if alltoall_buffer:
-                moe_method.max_model_len = 1
             layer = MagicMock()

             local_num_experts = 2
@@ -529,26 +526,21 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,

 class TestUnifiedApplyMLP(TestBase):

-    @patch('vllm_ascend.ops.fused_moe.get_forward_context')
-    @patch('vllm_ascend.ops.fused_moe.get_mc2_group')
-    @patch('vllm_ascend.ops.fused_moe.is_310p')
+    @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_dynamic_quant')
     @patch('torch_npu.npu_dequant_swiglu_quant')
     def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                      mock_npu_dynamic_quant,
                                                      mock_npu_grouped_matmul,
                                                      mock_is_310p,
-                                                     mock_get_mc2_group,
                                                      mock_get_forward_context):

         mock_forward_context = MagicMock()
         mock_forward_context.fused_moe_state = FusedMoEState.MC2
         mock_get_forward_context.return_value = mock_forward_context

-        mock_mc2_group = MagicMock()
-        mock_get_mc2_group.return_value = mock_mc2_group
-
         mock_is_310p.return_value = False

         mock_npu_dynamic_quant.return_value = (torch.randint(-128,
@@ -601,7 +593,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,

         self.assertEqual(result.dtype, torch.bfloat16)

-    @patch('vllm_ascend.ops.fused_moe.is_310p')
+    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
@@ -643,7 +635,7 @@ def test_unified_apply_mlp_without_quantization(self,
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.float16)

-    @patch('vllm_ascend.ops.fused_moe.get_forward_context')
+    @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
@@ -703,7 +695,7 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.bfloat16)

-    @patch('vllm_ascend.ops.fused_moe.is_310p')
+    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')

tests/ut/ops/test_token_dispatcher.py

Lines changed: 4 additions & 48 deletions
@@ -17,57 +17,13 @@

 from unittest.mock import MagicMock, PropertyMock, patch

-import pytest
 import torch
-from pytest_mock import MockerFixture

-from tests.ut.base import PytestBase, TestBase
+from tests.ut.base import TestBase
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
-    AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig,
-    TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
-    TokenDispatcherWithMC2, _Dispatchers, _register_token_dispatcher,
-    get_token_dispatcher, setup_token_dispatchers)
-
-
-class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):
-
-    @pytest.fixture
-    def config(self):
-        config = MoEDispatcherConfig()
-        config.set_num_local_experts(2)
-        config.set_num_moe_experts(4)
-        config.set_moe_pad_expert_input_to_capacity(False)
-        config.set_moe_expert_capacity_factor(None)
-        config.set_moe_router_topk(2)
-        config.set_moe_grouped_gemm(False)
-        config.set_group_topk(0)
-        config.set_num_groups(1)
-        config.set_is_fused(False)
-        return config.build()
-
-    def mock_ep_group(self, mocker):
-        mock_group = mocker.MagicMock()
-        mock_group.rank_in_group = 0
-        mock_group.world_size = 2
-        mock_group.device_group = "mock_group"
-        return mock_group
-
-    @pytest.fixture
-    def dispatcher(self, config, mocker: MockerFixture):
-        mocker.patch(
-            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
-            return_value=self.mock_ep_group(mocker))
-        mocker.patch("torch.npu.current_device", return_value="cpu")
-        mocker.patch("torch.npu.Stream", return_value=mocker.MagicMock)
-        return MoEAlltoAllSeqOverLapDispatcher(config)
-
-    def test_initialization(self, dispatcher, config):
-        assert dispatcher.num_local_experts == config.num_local_experts
-        assert dispatcher.num_experts == config.num_moe_experts
-        assert dispatcher.local_expert_indices == [0, 1]
-        assert dispatcher.ep_rank == 0
-        assert dispatcher.ep_size == 2
-        assert dispatcher.overlap_stream is not None
+    AscendSocVersion, TokenDispatcherWithAll2AllV,
+    TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
+    _register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)


 class TestTokenDispatcherWithMC2(TestBase):

tests/ut/torchair/ops/test_torchair_fused_moe.py

Lines changed: 3 additions & 8 deletions
@@ -353,8 +353,7 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
         else:
             assert result.shape == x.shape

-    @pytest.mark.parametrize("others_param",
-                             [[16, False], [1, True], [1, False], [4, False]])
+    @pytest.mark.parametrize("others_param", [16, 1, 4])
     def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                    mock_moe_env, others_param):
         """
@@ -363,22 +362,18 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
         3 test use_select_experts and fused_experts_with_all2all
         4 test use_select_experts and fused_experts
         """
-        ep_size, alltoall_buffer = others_param
+        ep_size = others_param
         is_prefill = False
         forward_context = MagicMock(
             fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
-        with patch("vllm_ascend.torchair.ops.torchair_fused_moe.MOE_ALL2ALL_BUFFER",
-                   alltoall_buffer), \
-            patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
+        with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
             patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
             expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
             if ep_size == 1:
                 x = x.view(-1, 2)
             router_logits = torch.randn(8, 8)
-            if alltoall_buffer:
-                moe_method.max_model_len = 1
             layer = MagicMock()
             layer.w13_weight = torch.randn(8, 16, 1)
             layer.w2_weight = torch.randn(16, 8, 1)
