|
64 | 64 | from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
65 | 65 | from vllm.pooling_params import PoolingParams
|
66 | 66 | from vllm.sampling_params import SamplingType
|
67 |
| -from vllm.sequence import IntermediateTensors, PoolerOutput |
| 67 | +from vllm.sequence import IntermediateTensors |
68 | 68 | from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
69 | 69 | from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
70 | 70 | LazyLoader, cdiv, get_dtype_size,
|
71 | 71 | is_pin_memory_available)
|
| 72 | +from vllm.utils.jsontree import json_map_leaves |
72 | 73 | from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
73 | 74 | from vllm.v1.attention.backends.utils import (
|
74 | 75 | AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
|
|
144 | 145 |
|
145 | 146 | if not vllm_version_is("0.10.2"):
|
146 | 147 | from vllm.v1.kv_cache_interface import UniformTypeKVCacheSpecs
|
| 148 | + from vllm.v1.outputs import PoolerOutput |
147 | 149 | else:
|
| 150 | + from vllm.sequence import PoolerOutput |
148 | 151 | UniformTypeKVCacheSpecs = None
|
149 | 152 |
|
150 | 153 |
|
@@ -1806,18 +1809,30 @@ def _pool(
|
1806 | 1809 | device=hidden_states.device)
|
1807 | 1810 | seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
|
1808 | 1811 |
|
1809 |
| - # Pooling models D2H & synchronize occurs in pooler.py:build_output |
1810 |
| - raw_pooler_output = self.model.pooler( |
1811 |
| - hidden_states=hidden_states, pooling_metadata=pooling_metadata) |
| 1812 | + if vllm_version_is("0.10.2"): |
| 1813 | + # Pooling models D2H & synchronize occurs in pooler.py:build_output |
| 1814 | + raw_pooler_output = self.model.pooler( |
| 1815 | + hidden_states=hidden_states, pooling_metadata=pooling_metadata) |
| 1816 | + else: |
| 1817 | + model = cast(VllmModelForPooling, self.model) |
| 1818 | + raw_pooler_output = model.pooler( |
| 1819 | + hidden_states=hidden_states, |
| 1820 | + pooling_metadata=pooling_metadata, |
| 1821 | + ) |
| 1822 | + raw_pooler_output = json_map_leaves( |
| 1823 | + lambda x: x.to("cpu", non_blocking=True), |
| 1824 | + raw_pooler_output, |
| 1825 | + ) |
| 1826 | + torch.npu.synchronize() |
1812 | 1827 |
|
1813 | 1828 | pooler_output: list[Optional[torch.Tensor]] = []
|
1814 | 1829 | for raw_output, seq_len, prompt_len in zip(
|
1815 | 1830 | raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
1816 |
| - |
1817 |
| - if seq_len == prompt_len: |
1818 |
| - pooler_output.append(raw_output.data) |
| 1831 | + if vllm_version_is("0.10.2"): |
| 1832 | + output = raw_output.data if seq_len == prompt_len else None |
1819 | 1833 | else:
|
1820 |
| - pooler_output.append(None) |
| 1834 | + output = raw_output if seq_len == prompt_len else None |
| 1835 | + pooler_output.append(output) |
1821 | 1836 |
|
1822 | 1837 | return ModelRunnerOutput(
|
1823 | 1838 | req_ids=self.input_batch.req_ids,
|
@@ -2582,7 +2597,10 @@ def _dummy_pooler_run(
|
2582 | 2597 | for task in self.get_supported_pooling_tasks():
|
2583 | 2598 | # Run a full batch with each task to ensure none of them OOMs
|
2584 | 2599 | output = self._dummy_pooler_run_task(hidden_states, task)
|
2585 |
| - output_size[task] = output.get_data_nbytes() |
| 2600 | + if vllm_version_is("0.10.2"): |
| 2601 | + output_size[task] = output.get_data_nbytes() |
| 2602 | + else: |
| 2603 | + output_size[task] = sum(o.nbytes for o in output) |
2586 | 2604 | del output # Allow GC
|
2587 | 2605 |
|
2588 | 2606 | max_task = max(output_size.items(), key=lambda x: x[1])[0]
|
|
0 commit comments