|
52 | 52 | from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
53 | 53 | from vllm.pooling_params import PoolingParams
|
54 | 54 | from vllm.sampling_params import SamplingType
|
55 |
| -from vllm.sequence import IntermediateTensors, PoolerOutput |
| 55 | +from vllm.sequence import IntermediateTensors |
56 | 56 | from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
57 | 57 | from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
58 | 58 | GiB_bytes, cdiv, check_use_alibi, get_dtype_size,
|
59 | 59 | is_pin_memory_available,
|
60 | 60 | length_from_prompt_token_ids_or_embeds, round_up,
|
61 | 61 | supports_dynamo)
|
| 62 | +from vllm.utils.jsontree import json_map_leaves |
62 | 63 | from vllm.v1.attention.backends.flash_attn import AttentionMetadata
|
63 | 64 | from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
64 | 65 | from vllm.v1.attention.backends.utils import (
|
|
79 | 80 | # yapf: enable
|
80 | 81 | from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
81 | 82 | DraftTokenIds, LogprobsLists, LogprobsTensors,
|
82 |
| - ModelRunnerOutput, SamplerOutput) |
| 83 | + ModelRunnerOutput, PoolerOutput, SamplerOutput) |
83 | 84 | from vllm.v1.pool.metadata import PoolingMetadata
|
84 | 85 | from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
85 | 86 | from vllm.v1.sample.metadata import SamplingMetadata
|
@@ -1823,15 +1824,22 @@ def _pool(
|
1823 | 1824 | device=hidden_states.device)
|
1824 | 1825 | seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
|
1825 | 1826 |
|
1826 |
| - # Pooling models D2H & synchronize occurs in pooler.py:build_output |
1827 |
| - raw_pooler_output = self.model.pooler( |
1828 |
| - hidden_states=hidden_states, pooling_metadata=pooling_metadata) |
| 1827 | + model = cast(VllmModelForPooling, self.model) |
| 1828 | + raw_pooler_output: PoolerOutput = model.pooler( |
| 1829 | + hidden_states=hidden_states, |
| 1830 | + pooling_metadata=pooling_metadata, |
| 1831 | + ) |
| 1832 | + raw_pooler_output = json_map_leaves( |
| 1833 | + lambda x: x.to("cpu", non_blocking=True), |
| 1834 | + raw_pooler_output, |
| 1835 | + ) |
| 1836 | + self._sync_device() |
1829 | 1837 |
|
1830 | 1838 | pooler_output: list[Optional[torch.Tensor]] = []
|
1831 | 1839 | for raw_output, seq_len, prompt_len in zip(
|
1832 | 1840 | raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
1833 | 1841 |
|
1834 |
| - output = raw_output.data if seq_len == prompt_len else None |
| 1842 | + output = raw_output if seq_len == prompt_len else None |
1835 | 1843 | pooler_output.append(output)
|
1836 | 1844 |
|
1837 | 1845 | return ModelRunnerOutput(
|
@@ -3233,7 +3241,7 @@ def _dummy_pooler_run(
|
3233 | 3241 | for task in self.get_supported_pooling_tasks():
|
3234 | 3242 | # Run a full batch with each task to ensure none of them OOMs
|
3235 | 3243 | output = self._dummy_pooler_run_task(hidden_states, task)
|
3236 |
| - output_size[task] = output.get_data_nbytes() |
| 3244 | + output_size[task] = sum(o.nbytes for o in output) |
3237 | 3245 | del output # Allow GC
|
3238 | 3246 |
|
3239 | 3247 | max_task = max(output_size.items(), key=lambda x: x[1])[0]
|
|
0 commit comments