Commit e7618d9
[2/N][Refactor][Qwen3-Next] remove redundant methods and patch methods in Qwen3NextGatedDeltaNet (#3082)
### What this PR does / why we need it?

Remove redundant methods and patched methods in `Qwen3NextGatedDeltaNet`, involving `causal_conv1d_fn`, `causal_conv1d_update_npu`, `fused_gdn_gating`, `fused_recurrent_gated_delta_rule`, `torch_chunk_gated_delta_rule`, and `RMSNormGated`.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

```
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100,
                                     temperature=0.6,
                                     top_k=40,
                                     top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="Qwen/Qwen3-Next-80B-A3B-Instruct",
        tensor_parallel_size=4,
        enforce_eager=True,
        trust_remote_code=True,
        max_model_len=256,
        gpu_memory_utilization=0.7,
        block_size=64,
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```

CI passed with newly added/existing tests.

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@5aeb925

---------

Signed-off-by: Icey <[email protected]>
1 parent eb205d9

File tree: 6 files changed, +669 −982 lines changed


vllm_ascend/models/qwen3_next.py

Lines changed: 12 additions & 106 deletions
```diff
@@ -6,7 +6,6 @@
 from typing import Optional
 
 import torch
-import torch.nn.functional as F
 from einops import rearrange
 from torch import nn
 from transformers.activations import ACT2FN
@@ -19,6 +18,10 @@
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.layers.fla.ops import RMSNormGated
+from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
+from vllm.model_executor.layers.fla.ops.fused_recurrent import \
+    fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.fused_moe import FusedMoE
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -34,6 +37,8 @@
     mamba_v2_sharded_weight_loader
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
+from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
+    causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -45,7 +50,8 @@
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
 from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
-                                                   Qwen3NextSparseMoeBlock)
+                                                   Qwen3NextSparseMoeBlock,
+                                                   fused_gdn_gating)
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, PPMissingLayer, extract_layer_index,
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
@@ -57,108 +63,6 @@
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 
-from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
-                                           causal_conv1d_update_npu)
-from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating
-from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
-
-
-def torch_chunk_gated_delta_rule(
-    query,
-    key,
-    value,
-    g,
-    beta,
-    chunk_size=64,
-    initial_state=None,
-    output_final_state=False,
-    use_qk_l2norm_in_kernel=False,
-):
-    initial_dtype = query.dtype
-    if use_qk_l2norm_in_kernel:
-        query = F.normalize(query, p=2, dim=-1)
-        key = F.normalize(key, p=2, dim=-1)
-    query, key, value, beta, g = [
-        x.transpose(1, 2).contiguous().to(torch.float32)
-        for x in (query, key, value, beta, g)
-    ]
-
-    batch_size, sequence_length, num_heads, k_head_dim = key.shape
-    v_head_dim = value.shape[-1]
-    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
-    query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
-    key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
-    value = F.pad(value, (0, 0, 0, pad_size))
-    beta = F.pad(beta, (0, pad_size))
-    g = F.pad(g, (0, pad_size))
-    tot_heads = num_heads + pad_size
-    scale = 1 / (query.shape[-1]**0.5)
-    query = query * scale
-
-    v_beta = value * beta.unsqueeze(-1)
-    k_beta = key * beta.unsqueeze(-1)
-    # reshape to chunks
-    query, key, value, k_beta, v_beta = [
-        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
-        for x in (query, key, value, k_beta, v_beta)
-    ]
-    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
-    mask = torch.triu(torch.ones(chunk_size,
-                                 chunk_size,
-                                 dtype=torch.bool,
-                                 device=query.device),
-                      diagonal=0)
-
-    # chunk decay
-    g = g.cumsum(dim=-1)
-    decay_mask = ((g.unsqueeze(-1) -
-                   g.unsqueeze(-2)).tril().exp().float()).tril()
-    attn = -(
-        (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
-    for i in range(1, chunk_size):
-        row = attn[..., i, :i].clone()
-        sub = attn[..., :i, :i].clone()
-        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
-    value = attn @ v_beta
-    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
-
-    last_recurrent_state = (torch.zeros(batch_size, sequence_length,
-                                        k_head_dim, v_head_dim).to(value) if
-                            initial_state is None else initial_state.to(value))
-
-    core_attn_out = torch.zeros_like(value)
-    mask = torch.triu(torch.ones(chunk_size,
-                                 chunk_size,
-                                 dtype=torch.bool,
-                                 device=query.device),
-                      diagonal=1)
-
-    # for each chunk
-    for i in range(0, tot_heads // chunk_size):
-        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
-        attn = (q_i @ k_i.transpose(-1, -2) *
-                decay_mask[:, :, i]).masked_fill_(mask, 0)
-        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        v_new = v_i - v_prime
-        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        last_recurrent_state = (
-            last_recurrent_state * g[:, :, i, -1, None, None].exp() +
-            (k_i *
-             (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
-                 -1, -2) @ v_new)
-
-    if not output_final_state:
-        last_recurrent_state = None
-    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
-                                          core_attn_out.shape[1], -1,
-                                          core_attn_out.shape[-1])
-    core_attn_out = core_attn_out[:, :, :num_heads]
-    core_attn_out = core_attn_out.transpose(1,
-                                            2).contiguous().to(initial_dtype)
-    return core_attn_out, last_recurrent_state
-
 
 class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
 
```
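For context, the deleted `torch_chunk_gated_delta_rule` was a pure-PyTorch fallback for the chunked gated delta rule, now served by the upstream `chunk_gated_delta_rule` kernel imported above. Below is a naive per-token sketch of the recurrence both versions compute, shown only to make the state update legible; the `[batch, seq, heads, dim]` layout and the function name are illustrative assumptions, not the kernel's API.

```python
import torch


def naive_gated_delta_rule(q, k, v, g, beta):
    # Per-token reference recurrence: decay the state by exp(g_t), then
    # apply a beta-weighted delta-rule write; the fused kernels compute
    # the same thing chunk-wise. Assumed shapes: q/k [b, t, h, d_k],
    # v [b, t, h, d_v], g/beta [b, t, h].
    b, t, h, d_k = k.shape
    d_v = v.shape[-1]
    state = q.new_zeros(b, h, d_k, d_v)  # one recurrent matrix per head
    outs = []
    for i in range(t):
        k_i, v_i = k[:, i], v[:, i]
        state = state * g[:, i, :, None, None].exp()
        v_err = v_i - torch.einsum('bhkv,bhk->bhv', state, k_i)
        state = state + beta[:, i, :, None, None] * torch.einsum(
            'bhk,bhv->bhkv', k_i, v_err)
        outs.append(torch.einsum('bhkv,bhk->bhv', state, q[:, i]))
    return torch.stack(outs, dim=1)  # [b, t, h, d_v]
```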
```diff
@@ -275,6 +179,8 @@ def __init__(
         self.norm = RMSNormGated(
             self.head_v_dim,
             eps=self.layer_norm_epsilon,
+            norm_before_gate=True,
+            device="npu",
         )
 
         self.out_proj = RowParallelLinear(self.value_dim,
```
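The two new keyword arguments make explicit what the removed `vllm_ascend.ops.fla.RMSNormGated` patch used to hard-code. As a rough sketch of the semantics (an assumption based on the Mamba-style gated RMSNorm, not the fused fla kernel itself), `norm_before_gate=True` normalizes the hidden states first and applies the SiLU-activated gate afterwards:

```python
import torch
import torch.nn.functional as F


def gated_rms_norm(x, gate, weight, eps=1e-6, norm_before_gate=True):
    # Illustrative semantics only; RMSNormGated runs as a fused kernel.
    if not norm_before_gate:
        x = x * F.silu(gate)  # gate first, then normalize
    y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
    if norm_before_gate:
        y = y * F.silu(gate)  # normalize first, then gate
    return y
```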
```diff
@@ -467,7 +373,7 @@ def _forward(
                 query_start_loc=non_spec_query_start_loc,
             ).transpose(0, 1)
         elif attn_metadata.num_decodes > 0:
-            mixed_qkv_non_spec = causal_conv1d_update_npu(
+            mixed_qkv_non_spec = causal_conv1d_update(
                 mixed_qkv_non_spec,
                 conv_state,
                 conv_weights,
```
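During decode each request contributes a single token, so `causal_conv1d_update` only shifts the per-request convolution state and emits one output per channel, rather than rerunning the full prefill convolution. A simplified single-request sketch of that step (the real kernel also handles activation, batching, and state indexing):

```python
import torch


def conv1d_decode_step(x, conv_state, conv_weights):
    # x: [dim] new token; conv_state: [dim, width] rolling buffer;
    # conv_weights: [dim, width] depthwise filter. Illustrative only.
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, -1] = x                       # append the newest token
    return (conv_state * conv_weights).sum(-1)  # causal depthwise conv out
```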
```diff
@@ -551,7 +457,7 @@ def _forward(
                 (
                     cur_core_attn_out_non_spec,
                     cur_last_recurrent_state,
-                ) = torch_chunk_gated_delta_rule(
+                ) = chunk_gated_delta_rule(
                     query=cur_q,
                     key=cur_k,
                     value=cur_v,
```
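The upstream kernel keeps the keyword interface of the removed PyTorch fallback, so only the function name changes at the call site. A hedged usage sketch: everything beyond the `query`/`key`/`value` keywords visible in the diff follows the old fallback's signature, and the `cur_*` names are placeholders.

```python
# Placeholder tensors; shapes follow the fallback's [batch, seq, heads, dim]
# convention and are assumptions for illustration.
core_attn_out, last_state = chunk_gated_delta_rule(
    query=cur_q,
    key=cur_k,
    value=cur_v,
    g=cur_g,                   # per-head log decay
    beta=cur_b,                # delta-rule write strength
    initial_state=cur_state,   # carried recurrent state, or None
    output_final_state=True,
    use_qk_l2norm_in_kernel=True,
)
```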
