@@ -6,7 +6,6 @@
 from typing import Optional
 
 import torch
-import torch.nn.functional as F
 from einops import rearrange
 from torch import nn
 from transformers.activations import ACT2FN
@@ -19,6 +18,10 @@
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.layers.fla.ops import RMSNormGated
+from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
+from vllm.model_executor.layers.fla.ops.fused_recurrent import \
+    fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.fused_moe import FusedMoE
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -34,6 +37,8 @@
     mamba_v2_sharded_weight_loader)
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
+from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
+    causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -45,7 +50,8 @@
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
 from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
-                                                   Qwen3NextSparseMoeBlock)
+                                                   Qwen3NextSparseMoeBlock,
+                                                   fused_gdn_gating)
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, PPMissingLayer, extract_layer_index,
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
@@ -57,108 +63,6 @@
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 
-from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
-                                           causal_conv1d_update_npu)
-from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating
-from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
-
-
-def torch_chunk_gated_delta_rule(
-    query,
-    key,
-    value,
-    g,
-    beta,
-    chunk_size=64,
-    initial_state=None,
-    output_final_state=False,
-    use_qk_l2norm_in_kernel=False,
-):
-    initial_dtype = query.dtype
-    if use_qk_l2norm_in_kernel:
-        query = F.normalize(query, p=2, dim=-1)
-        key = F.normalize(key, p=2, dim=-1)
-    query, key, value, beta, g = [
-        x.transpose(1, 2).contiguous().to(torch.float32)
-        for x in (query, key, value, beta, g)
-    ]
-
-    batch_size, sequence_length, num_heads, k_head_dim = key.shape
-    v_head_dim = value.shape[-1]
-    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
-    query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
-    key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
-    value = F.pad(value, (0, 0, 0, pad_size))
-    beta = F.pad(beta, (0, pad_size))
-    g = F.pad(g, (0, pad_size))
-    tot_heads = num_heads + pad_size
-    scale = 1 / (query.shape[-1]**0.5)
-    query = query * scale
-
-    v_beta = value * beta.unsqueeze(-1)
-    k_beta = key * beta.unsqueeze(-1)
-    # reshape to chunks
-    query, key, value, k_beta, v_beta = [
-        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
-        for x in (query, key, value, k_beta, v_beta)
-    ]
-    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
-    mask = torch.triu(torch.ones(chunk_size,
-                                 chunk_size,
-                                 dtype=torch.bool,
-                                 device=query.device),
-                      diagonal=0)
-
-    # chunk decay
-    g = g.cumsum(dim=-1)
-    decay_mask = ((g.unsqueeze(-1) -
-                   g.unsqueeze(-2)).tril().exp().float()).tril()
-    attn = -(
-        (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
-    for i in range(1, chunk_size):
-        row = attn[..., i, :i].clone()
-        sub = attn[..., :i, :i].clone()
-        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
-    value = attn @ v_beta
-    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
-
-    last_recurrent_state = (torch.zeros(batch_size, sequence_length,
-                                        k_head_dim, v_head_dim).to(value) if
-                            initial_state is None else initial_state.to(value))
-
-    core_attn_out = torch.zeros_like(value)
-    mask = torch.triu(torch.ones(chunk_size,
-                                 chunk_size,
-                                 dtype=torch.bool,
-                                 device=query.device),
-                      diagonal=1)
-
-    # for each chunk
-    for i in range(0, tot_heads // chunk_size):
-        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
-        attn = (q_i @ k_i.transpose(-1, -2) *
-                decay_mask[:, :, i]).masked_fill_(mask, 0)
-        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        v_new = v_i - v_prime
-        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        last_recurrent_state = (
-            last_recurrent_state * g[:, :, i, -1, None, None].exp() +
-            (k_i *
-             (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
-                 -1, -2) @ v_new)
-
-    if not output_final_state:
-        last_recurrent_state = None
-    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
-                                          core_attn_out.shape[1], -1,
-                                          core_attn_out.shape[-1])
-    core_attn_out = core_attn_out[:, :, :num_heads]
-    core_attn_out = core_attn_out.transpose(1,
-                                            2).contiguous().to(initial_dtype)
-    return core_attn_out, last_recurrent_state
-
 
 class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
 
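The deleted torch_chunk_gated_delta_rule was a pure-PyTorch reference implementation of the chunked gated delta rule; the hunks below swap its call sites over to vLLM's fused FLA kernels. A minimal sketch of the replacement call, assuming the fused kernel keeps the deleted reference's keyword interface (the query/key/value keywords are confirmed by the call-site hunk further down; g, beta, initial_state, output_final_state, and use_qk_l2norm_in_kernel are carried over from the deleted signature, and the toy shapes below ignore any device/dtype constraints of the real kernel):

import torch

from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule

# Toy shapes, following the deleted reference: (batch, seq, heads, head_dim).
b, s, h, dk, dv = 1, 128, 4, 128, 128
q = torch.randn(b, s, h, dk)
k = torch.randn(b, s, h, dk)
v = torch.randn(b, s, h, dv)
g = -torch.rand(b, s, h)                    # log-space decay gates (<= 0)
beta = torch.sigmoid(torch.randn(b, s, h))  # delta-rule write strengths in (0, 1)

core_attn_out, final_state = chunk_gated_delta_rule(
    query=q, key=k, value=v, g=g, beta=beta,
    initial_state=None, output_final_state=True,
    use_qk_l2norm_in_kernel=True,
)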
@@ -275,6 +179,8 @@ def __init__(
         self.norm = RMSNormGated(
             self.head_v_dim,
             eps=self.layer_norm_epsilon,
+            norm_before_gate=True,
+            device="npu",
         )
 
         self.out_proj = RowParallelLinear(self.value_dim,
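The two added arguments pin down the norm's placement on the NPU: norm_before_gate=True conventionally means hidden states are normalized first and only then multiplied by the activated gate. A pure-PyTorch sketch of that convention (an assumption modeled on the mamba-style gated-norm kernels, not this kernel's actual implementation):

import torch
import torch.nn.functional as F

def gated_rms_norm(x, gate, weight, eps=1e-6, norm_before_gate=True):
    # RMS-normalize over the last dim, with a learned per-channel scale.
    def rms_norm(t):
        return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps) * weight
    if norm_before_gate:
        return rms_norm(x) * F.silu(gate)   # y = RMSNorm(x) * SiLU(z)
    return rms_norm(x * F.silu(gate))       # y = RMSNorm(x * SiLU(z))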
@@ -467,7 +373,7 @@ def _forward(
                 query_start_loc=non_spec_query_start_loc,
             ).transpose(0, 1)
         elif attn_metadata.num_decodes > 0:
-            mixed_qkv_non_spec = causal_conv1d_update_npu(
+            mixed_qkv_non_spec = causal_conv1d_update(
                 mixed_qkv_non_spec,
                 conv_state,
                 conv_weights,
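causal_conv1d_update is the single-token decode path: rather than convolving a whole prefill as causal_conv1d_fn does, it rolls a per-channel window state forward by one sample and takes a windowed dot product. A naive sketch of those semantics (the helper name and shapes are illustrative; the fused op additionally handles the activation and batched cache indexing):

import torch

def causal_conv1d_update_ref(x, conv_state, weight, bias=None):
    # x: (batch, dim) new token; conv_state: (batch, dim, width) rolling
    # window; weight: (dim, width) depthwise filter taps.
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, :, -1] = x
    out = torch.einsum("bdw,dw->bd", conv_state, weight)
    if bias is not None:
        out = out + bias
    return out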
@@ -551,7 +457,7 @@ def _forward(
                 (
                     cur_core_attn_out_non_spec,
                     cur_last_recurrent_state,
-                ) = torch_chunk_gated_delta_rule(
+                ) = chunk_gated_delta_rule(
                     query=cur_q,
                     key=cur_k,
                     value=cur_v,