Merged
32 commits
f352b0d  upstream (ywang96, Sep 12, 2025)
667e973  fix & add co-author (ywang96, Sep 12, 2025)
8fe70a7  Update tests/models/registry.py (ywang96, Sep 12, 2025)
5f6afa1  add missing processor test (Isotr0py, Sep 12, 2025)
9c5808f  revert str (ywang96, Sep 12, 2025)
c7fe668  fix processor test hashing (Isotr0py, Sep 12, 2025)
aa2330f  fix frames indices (Isotr0py, Sep 12, 2025)
e1f8397  fix hit_rate 1.0 (Isotr0py, Sep 12, 2025)
9c7939c  typo (Isotr0py, Sep 12, 2025)
5d4f6dd  fix placeholder replacement (Isotr0py, Sep 12, 2025)
0d88363  fix video example (Isotr0py, Sep 13, 2025)
574884c  Merge branch 'main' into upstream-qwen-3-vl (ywang96, Sep 13, 2025)
6ec3968  fix vit backend (ywang96, Sep 13, 2025)
0f80a19  fix online serving metadata (Isotr0py, Sep 13, 2025)
ba54870  avoid hardcode fps=1 (Isotr0py, Sep 14, 2025)
f7c37a9  oops fps=1 (Isotr0py, Sep 14, 2025)
a6c5d7e  Merge branch 'main' into upstream-qwen-3-vl (ywang96, Sep 15, 2025)
cbf6dee  Merge branch 'main' into upstream-qwen-3-vl (Isotr0py, Sep 15, 2025)
bec9e7e  catch up and fix processor test (Isotr0py, Sep 15, 2025)
5f3cf0e  Merge branch 'main' into upstream-qwen-3-vl (ywang96, Sep 16, 2025)
5027f31  fix model path (ywang96, Sep 16, 2025)
d0133e2  fix qwen_vl_utils compatibility (Isotr0py, Sep 16, 2025)
07e4f52  fix fps (ywang96, Sep 16, 2025)
78afc5b  Merge branch 'main' into upstream-qwen-3-vl (ywang96, Sep 16, 2025)
44f89b2  fix (ywang96, Sep 16, 2025)
c7ea6f7  cleanup (ywang96, Sep 16, 2025)
10bd983  do not modify metadata (ywang96, Sep 16, 2025)
a4c0d34  fix example and online fps (Isotr0py, Sep 16, 2025)
e11895e  Merge branch 'main' into upstream-qwen-3-vl (ywang96, Sep 16, 2025)
438f146  clarify (ywang96, Sep 16, 2025)
52e2b2b  Merge branch 'main' into upstream-qwen-3-vl (ywang96, Sep 17, 2025)
887a833  Merge branch 'main' into upstream-qwen-3-vl (ywang96, Sep 17, 2025)
2 changes: 2 additions & 0 deletions docs/models/supported_models.md
@@ -661,6 +661,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
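The two new table rows are reachable through vLLM's standard OpenAI-compatible multimodal chat API. A minimal client sketch, assuming a server started locally with `vllm serve Qwen/Qwen3-VL-4B-Instruct` on the default port; the base URL and image URL below are placeholders:

```python
# Minimal sketch; endpoint, API key, and image URL are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen3-VL-4B-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://example.com/demo.jpg"}},
            {"type": "text", "text": "Describe this image."},
        ],
    }],
    max_tokens=64,
)
print(response.choices[0].message.content)
```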
78 changes: 78 additions & 0 deletions examples/offline_inference/vision_language.py
@@ -1437,6 +1437,80 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
)


# Qwen3-VL-Dense
def run_qwen3_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen3-VL-4B-Instruct"

engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
)

if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"

prompts = [
(
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]

return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)


# Qwen3-VL-MOE
def run_qwen3_vl_moe(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen3-VL-30B-A3B-Instruct"

engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
)

if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"

prompts = [
(
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]

return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)


# R-4B
def run_r_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1645,6 +1719,8 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
"qwen2_5_omni": run_qwen2_5_omni,
"qwen3_vl": run_qwen3_vl,
"qwen3_vl_moe": run_qwen3_vl_moe,
"rvl": run_r_vl,
"skywork_chat": run_skyworkr1v,
"smolvlm": run_smolvlm,
@@ -1658,6 +1734,8 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
"glm4_1v",
"glm4_5v",
"glm4_5v_fp8",
"qwen3_vl",
"qwen3_vl_moe",
]


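The two new runners return prompts in the raw chat-template format shown above. A minimal sketch of consuming that format directly with `LLM.generate` (the image path is a placeholder, and the engine arguments mirror `run_qwen3_vl`):

```python
# Offline sketch, assuming a local image file; prompt format mirrors run_qwen3_vl.
from PIL import Image
from vllm import LLM, SamplingParams

question = "What is in this image?"
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    f"{question}<|im_end|>\n"
    "<|im_start|>assistant\n"
)

llm = LLM(
    model="Qwen/Qwen3-VL-4B-Instruct",
    max_model_len=4096,
    limit_mm_per_prompt={"image": 1},
)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("demo.jpg")}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```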
35 changes: 34 additions & 1 deletion tests/models/multimodal/processing/test_common.py
@@ -31,6 +31,7 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
"""
# Ensure video metadata is included
if "video" in mm_data:
# GLM4.1V doesn't support multiple videos
video = mm_data["video"]
num_frames = len(video)
mm_data["video"] = (video, {
@@ -44,6 +45,34 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
return mm_data


def qwen3_vl_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
"""
Patch the multimodal data for Qwen3-VL model.
"""

def create_metadata(frames: np.ndarray):
num_frames = len(frames)
return {
"total_num_frames": num_frames,
"fps": 2.0,
"duration": num_frames / 2.0,
"video_backend": "opencv",
"frames_indices": list(range(num_frames)),
"do_sample_frames": True,
}

# Ensure video metadata is included
if "video" in mm_data:
video = mm_data["video"]
if isinstance(video, list):
# multiple videos
mm_data["video"] = [(vid, create_metadata(vid)) for vid in video]
else:
# single video
mm_data["video"] = (video, create_metadata(video))
return mm_data


def _test_processing_correctness(
model_id_or_arch: str,
hit_rate: float,
@@ -182,8 +211,10 @@ def _test_processing_correctness(
}

MM_DATA_PATCHES = {
# GLM4.1V requires video metadata to be included in the input
# GLM4.1V and Qwen3-VL require video metadata to be included in the input
"glm4v": glm4_1v_patch_mm_data,
"qwen3_vl": qwen3_vl_patch_mm_data,
"qwen3_vl_moe": qwen3_vl_patch_mm_data,
}


@@ -326,6 +357,8 @@ def _test_processing_correctness_one(
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"Qwen/Qwen2.5-Omni-3B",
"Qwen/Qwen3-VL-4B-Instruct",
"Qwen/Qwen3-VL-30B-A3B-Instruct",
"YannQi/R-4B",
"Skywork/Skywork-R1V-38B",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
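For context, the `(frames, metadata)` tuple that `qwen3_vl_patch_mm_data` builds looks like the sketch below; the frame array is a dummy placeholder, and a real pipeline would derive `fps`, `duration`, and `frames_indices` from the video decoder.

```python
# Sketch of the video payload shape the patch helper above produces.
import numpy as np

frames = np.zeros((16, 224, 224, 3), dtype=np.uint8)  # 16 dummy RGB frames
metadata = {
    "total_num_frames": len(frames),
    "fps": 2.0,
    "duration": len(frames) / 2.0,
    "video_backend": "opencv",
    "frames_indices": list(range(len(frames))),
    "do_sample_frames": True,
}
mm_data = {"video": (frames, metadata)}
```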
6 changes: 6 additions & 0 deletions tests/models/registry.py
@@ -557,6 +557,12 @@ def check_available_online(
max_model_len=4096),
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
"Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57"), # noqa: E501
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57"),
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/rotary_embedding/__init__.py
@@ -103,6 +103,8 @@ def get_rope(
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved",
False),
)
else:
rotary_emb = RotaryEmbedding(
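`get_rope` now forwards an optional `mrope_interleaved` flag read from `rope_scaling`. An illustrative fragment of such a dict is shown below; the section sizes are assumptions rather than values from a released checkpoint, and their sum must equal `rotary_dim // 2`:

```python
# Illustrative rope_scaling fragment; get_rope() reads exactly these two keys.
rope_scaling_fragment = {
    "mrope_section": [24, 20, 20],   # 24 + 20 + 20 == 64 for a 128-dim head
    "mrope_interleaved": True,       # defaults to False when the key is absent
}
```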
144 changes: 133 additions & 11 deletions vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -177,6 +177,18 @@ def triton_mrope(
return q, k


def apply_interleaved_rope(x: torch.Tensor,
mrope_section: list[int]) -> torch.Tensor:
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THWTHWTHW...TT], preserving frequency continuity.
"""
x_t = x[0].clone()
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
return x_t


class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""

@@ -189,6 +201,7 @@ def __init__(
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[list[int]] = None,
mrope_interleaved: Optional[bool] = False,
) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
@@ -198,6 +211,7 @@
base, is_neox_style, dtype)

self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2

@@ -225,17 +239,20 @@ def forward_native(
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section

cos = torch.cat([
m[i]
for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
],
dim=-1)
sin = torch.cat([
m[i]
for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
],
dim=-1)
if self.mrope_interleaved:
cos = apply_interleaved_rope(cos, self.mrope_section)
sin = apply_interleaved_rope(sin, self.mrope_section)
else:
cos = torch.cat([
m[i] for i, m in enumerate(
cos.split(self.mrope_section, dim=-1))
],
dim=-1)
sin = torch.cat([
m[i] for i, m in enumerate(
sin.split(self.mrope_section, dim=-1))
],
dim=-1)

query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
@@ -265,6 +282,10 @@ def forward_cuda(
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None

if self.mrope_interleaved:
# TODO: add triton implementation to support mrope-interleaved
return self.forward_native(positions, query, key)

num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
@@ -388,6 +409,15 @@ def get_input_positions_tensor(
context_len=context_len,
seq_len=seq_len,
)
elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
return cls._qwen3vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
return cls._ernie_get_input_positions_tensor(
input_tokens=input_tokens,
@@ -526,6 +556,98 @@ def _glm4v_get_input_positions_tensor(
len(input_tokens)).item()
return llm_positions, mrope_position_delta

@classmethod
def _qwen3vl_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""

video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
for _ in range(t)]

image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size

input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []

st = 0
remain_images, remain_videos = image_nums, video_nums

image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video

llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st

st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w

if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta

@classmethod
def _ernie_get_input_positions_tensor(
cls,
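A small numeric check of the layout `apply_interleaved_rope` produces, using toy sizes (`mrope_section = [4, 2, 2]`, so `rotary_dim // 2 == 8`): the chunked split would lay frequencies out as `TTTTHHWW`, while the interleaved path yields `THWTHWTT`.

```python
# Toy-sized illustration of the interleaved MRoPE layout; mirrors the body of
# apply_interleaved_rope with made-up section sizes.
import torch

mrope_section = [4, 2, 2]            # T/H/W frequency counts, sum == 8
x = torch.stack([
    torch.full((1, 8), 0.0),         # T plane
    torch.full((1, 8), 1.0),         # H plane
    torch.full((1, 8), 2.0),         # W plane
])

x_t = x[0].clone()
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
print(x_t)  # tensor([[0., 1., 2., 0., 1., 2., 0., 0.]]) -> T H W T H W T T
```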
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2.py
@@ -285,7 +285,7 @@ def __init__(self,
decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer):
super().__init__()

config = vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_config.get_text_config()
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

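The one-line `get_text_config()` change lets this module accept a composite multimodal config: for VLM configs it returns the nested text config, while plain text configs return themselves. A quick behavior sketch, assuming `transformers >= 4.57` for the Qwen3-VL config class as required by the registry entries above:

```python
# Behavior sketch of PretrainedConfig.get_text_config().
from transformers import AutoConfig

vlm_cfg = AutoConfig.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
print(type(vlm_cfg.get_text_config()).__name__)   # nested text config class

text_cfg = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")
print(text_cfg.get_text_config() is text_cfg)     # True: text-only configs return self
```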