Skip to content

Commit 755ed7b

Browse files
[Misc] Simplify PoolerOutput and move to v1/outputs (#25629)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent a676e66 commit 755ed7b

File tree

6 files changed: +34 additions, -82 deletions

vllm/executor/executor_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
1616
from vllm.logger import init_logger
1717
from vllm.lora.request import LoRARequest
18-
from vllm.sequence import ExecuteModelRequest, PoolerOutput
18+
from vllm.sequence import ExecuteModelRequest
1919
from vllm.tasks import SupportedTask
2020
from vllm.utils import make_async
21-
from vllm.v1.outputs import SamplerOutput
21+
from vllm.v1.outputs import PoolerOutput, SamplerOutput
2222
from vllm.worker.worker_base import WorkerBase
2323

2424
logger = init_logger(__name__)

vllm/model_executor/layers/pooler.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from vllm.logger import init_logger
1717
from vllm.model_executor.models.adapters import _load_st_projector
1818
from vllm.pooling_params import PoolingParams
19-
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
2019
from vllm.tasks import PoolingTask
21-
from vllm.utils import current_stream, resolve_obj_by_qualname
20+
from vllm.utils import resolve_obj_by_qualname
21+
from vllm.v1.outputs import PoolerOutput
2222
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
2323

2424
logger = init_logger(__name__)
@@ -190,19 +190,6 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
190190
return PoolerClassify()
191191

192192

193-
def build_output(
194-
all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
195-
# Pooling models D2H & synchronize occurs here
196-
if isinstance(all_data, list):
197-
all_data = [d.to("cpu", non_blocking=True) for d in all_data]
198-
else:
199-
all_data = all_data.to("cpu", non_blocking=True)
200-
current_stream().synchronize()
201-
202-
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
203-
return PoolerOutput(outputs=all_outputs)
204-
205-
206193
class PoolingMethod(nn.Module, ABC):
207194

208195
@staticmethod
@@ -556,7 +543,7 @@ def forward(
556543
) -> PoolerOutput:
557544
pooled_data = self.pooling(hidden_states, pooling_metadata)
558545
pooled_data = self.head(pooled_data, pooling_metadata)
559-
return build_output(pooled_data)
546+
return pooled_data
560547

561548

562549
class StepPooler(Pooler):
@@ -607,7 +594,7 @@ def forward(
607594
) -> PoolerOutput:
608595
pooled_data = self.extract_states(hidden_states, pooling_metadata)
609596
pooled_data = self.head(pooled_data, pooling_metadata)
610-
return build_output(pooled_data)
597+
return pooled_data
611598

612599

613600
class ClassifierPooler(Pooler):
@@ -678,7 +665,7 @@ def forward(
678665
]
679666

680667
# scores shape: [batchsize, num_labels]
681-
return build_output(scores)
668+
return scores
682669

683670

684671
class DispatchPooler(Pooler):
@@ -708,7 +695,7 @@ def forward(
708695
) -> PoolerOutput:
709696
poolers_by_task = self.poolers_by_task
710697

711-
outputs = list[PoolingSequenceGroupOutput]()
698+
outputs = list[torch.Tensor]()
712699
offset = 0
713700
for task, group in groupby(get_tasks(pooling_metadata)):
714701
if not (pooler := poolers_by_task.get(task)):
@@ -722,10 +709,10 @@ def forward(
722709
pooling_metadata[offset:offset + num_items],
723710
)
724711

725-
outputs.extend(group_output.outputs)
712+
outputs.extend(group_output)
726713
offset += num_items
727714

728-
return PoolerOutput(outputs)
715+
return outputs
729716

730717
def extra_repr(self) -> str:
731718
s = f"supported_task={self.get_supported_tasks()}"

vllm/model_executor/models/gritlm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
1313
PoolerHead, PoolerNormalize,
1414
PoolingParamsUpdate,
15-
build_output, get_prompt_lens,
15+
get_prompt_lens,
1616
get_prompt_token_ids)
1717
from vllm.model_executor.models.llama import LlamaForCausalLM
18-
from vllm.sequence import PoolerOutput
1918
from vllm.tasks import PoolingTask
2019
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
20+
from vllm.v1.outputs import PoolerOutput
2121
from vllm.v1.pool.metadata import PoolingMetadata
2222

2323
from .interfaces_base import default_pooling_type
@@ -212,7 +212,7 @@ def forward(
212212
) -> PoolerOutput:
213213
pooled_data = self.pooling(hidden_states, pooling_metadata)
214214
pooled_data = self.head(pooled_data, pooling_metadata)
215-
return build_output(pooled_data)
215+
return pooled_data
216216

217217

218218
@default_pooling_type("MEAN")

vllm/sequence.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from vllm.v1.worker.kv_connector_model_runner_mixin import (
1212
KVConnectorOutput)
1313
else:
14-
LoRARequest = Any
1514
KVConnectorOutput = Any
1615

1716
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@@ -48,29 +47,6 @@ class RequestMetrics:
4847
model_execute_time: Optional[float] = None
4948

5049

51-
class PoolingSequenceGroupOutput(
52-
msgspec.Struct,
53-
omit_defaults=True, # type: ignore[call-arg]
54-
array_like=True, # type: ignore[call-arg]
55-
):
56-
"""The model output associated with a pooling sequence group."""
57-
# Annotated as Any to be compatible with msgspec
58-
# The actual type is in SequenceGroup.pooled_data
59-
data: Any
60-
61-
def get_data_nbytes(self) -> int:
62-
data: torch.Tensor = self.data
63-
return data.nbytes
64-
65-
def __repr__(self) -> str:
66-
return f"PoolingSequenceGroupOutput(data={self.data}"
67-
68-
def __eq__(self, other: object) -> bool:
69-
if not isinstance(other, PoolingSequenceGroupOutput):
70-
raise NotImplementedError()
71-
return self.data == other.data
72-
73-
7450
# cannot use msgspec.Struct here because Dynamo does not support it
7551
@dataclass
7652
class IntermediateTensors:
@@ -119,30 +95,6 @@ def __repr__(self) -> str:
11995
return f"IntermediateTensors(tensors={self.tensors})"
12096

12197

122-
class PoolerOutput(
123-
msgspec.Struct,
124-
omit_defaults=True, # type: ignore[call-arg]
125-
array_like=True): # type: ignore[call-arg]
126-
"""The output from a pooling operation in the pooling model."""
127-
outputs: list[PoolingSequenceGroupOutput]
128-
129-
def get_data_nbytes(self) -> int:
130-
return sum(o.get_data_nbytes() for o in self.outputs)
131-
132-
def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
133-
return self.outputs[idx]
134-
135-
def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
136-
self.outputs[idx] = value
137-
138-
def __len__(self):
139-
return len(self.outputs)
140-
141-
def __eq__(self, other: object):
142-
return isinstance(other,
143-
self.__class__) and self.outputs == other.outputs
144-
145-
14698
class ExecuteModelRequest(
14799
msgspec.Struct,
148100
array_like=True, # type: ignore[call-arg]

vllm/v1/outputs.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from abc import ABC, abstractmethod
55
from dataclasses import dataclass
6-
from typing import TYPE_CHECKING, NamedTuple, Optional
6+
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
77

88
import torch
99

@@ -65,6 +65,11 @@ def empty_cpu(num_positions: int,
6565
)
6666

6767

68+
# [num_reqs, <dynamic>]
69+
# The shape of each element depends on the pooler used
70+
PoolerOutput = Union[torch.Tensor, list[torch.Tensor]]
71+
72+
6873
@dataclass
6974
class SamplerOutput:
7075

vllm/v1/worker/gpu_model_runner.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,14 @@
5252
from vllm.multimodal.utils import group_mm_kwargs_by_modality
5353
from vllm.pooling_params import PoolingParams
5454
from vllm.sampling_params import SamplingType
55-
from vllm.sequence import IntermediateTensors, PoolerOutput
55+
from vllm.sequence import IntermediateTensors
5656
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
5757
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
5858
GiB_bytes, cdiv, check_use_alibi, get_dtype_size,
5959
is_pin_memory_available,
6060
length_from_prompt_token_ids_or_embeds, round_up,
6161
supports_dynamo)
62+
from vllm.utils.jsontree import json_map_leaves
6263
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
6364
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
6465
from vllm.v1.attention.backends.utils import (
@@ -79,7 +80,7 @@
7980
# yapf: enable
8081
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
8182
DraftTokenIds, LogprobsLists, LogprobsTensors,
82-
ModelRunnerOutput, SamplerOutput)
83+
ModelRunnerOutput, PoolerOutput, SamplerOutput)
8384
from vllm.v1.pool.metadata import PoolingMetadata
8485
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
8586
from vllm.v1.sample.metadata import SamplingMetadata
@@ -1823,15 +1824,22 @@ def _pool(
18231824
device=hidden_states.device)
18241825
seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
18251826

1826-
# Pooling models D2H & synchronize occurs in pooler.py:build_output
1827-
raw_pooler_output = self.model.pooler(
1828-
hidden_states=hidden_states, pooling_metadata=pooling_metadata)
1827+
model = cast(VllmModelForPooling, self.model)
1828+
raw_pooler_output: PoolerOutput = model.pooler(
1829+
hidden_states=hidden_states,
1830+
pooling_metadata=pooling_metadata,
1831+
)
1832+
raw_pooler_output = json_map_leaves(
1833+
lambda x: x.to("cpu", non_blocking=True),
1834+
raw_pooler_output,
1835+
)
1836+
self._sync_device()
18291837

18301838
pooler_output: list[Optional[torch.Tensor]] = []
18311839
for raw_output, seq_len, prompt_len in zip(
18321840
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
18331841

1834-
output = raw_output.data if seq_len == prompt_len else None
1842+
output = raw_output if seq_len == prompt_len else None
18351843
pooler_output.append(output)
18361844

18371845
return ModelRunnerOutput(
@@ -3233,7 +3241,7 @@ def _dummy_pooler_run(
32333241
for task in self.get_supported_pooling_tasks():
32343242
# Run a full batch with each task to ensure none of them OOMs
32353243
output = self._dummy_pooler_run_task(hidden_states, task)
3236-
output_size[task] = output.get_data_nbytes()
3244+
output_size[task] = sum(o.nbytes for o in output)
32373245
del output # Allow GC
32383246

32393247
max_task = max(output_size.items(), key=lambda x: x[1])[0]

Commit comments: 0