Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/offline_inference/pooling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
python examples/offline_inference/pooling/embed_matryoshka_fy.py
```

## Multi vector retrieval usage

```bash
python examples/offline_inference/pooling/multi_vector_retrieval.py
```

## Named Entity Recognition (NER) usage

```bash
Expand Down
56 changes: 56 additions & 0 deletions examples/offline_inference/pooling/multi_vector_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="BAAI/bge-m3",
runner="pooling",
enforce_eager=True,
)
return parser.parse_args()


def main(args: Namespace):
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# Create an LLM.
# You should pass runner="pooling" for embedding models
llm = LLM(**vars(args))

# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = llm.embed(prompts)

# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
print(len(embeds))

# Generate embedding for each token. The output is a list of PoolingRequestOutput.
outputs = llm.encode(prompts, pooling_task="token_embed")

# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
multi_vector = output.outputs.data
print(multi_vector.shape)


if __name__ == "__main__":
args = parse_args()
main(args)
8 changes: 7 additions & 1 deletion examples/online_serving/pooling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@ python examples/online_serving/pooling/cohere_rerank_client.py
python examples/online_serving/pooling/jinaai_rerank_client.py
```

## Multi vector retrieval usage

```bash
python examples/online_serving/pooling/multi_vector_retrieval_client.py
```

## Named Entity Recognition (NER) usage

```bash
python examples/online_serving/pooling/ner.py
python examples/online_serving/pooling/ner_client.py
```

## Openai chat embedding for multimodal usage
Expand Down
54 changes: 54 additions & 0 deletions examples/online_serving/pooling/multi_vector_retrieval_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Example online usage of Pooling API for multi vector retrieval.

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.

vllm serve BAAI/bge-m3
"""

import argparse

import requests
import torch


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
headers = {"User-Agent": "Test Client"}
response = requests.post(api_url, headers=headers, json=prompt)
return response


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", type=str, default="BAAI/bge-m3")

return parser.parse_args()


def main(args):
api_url = f"http://{args.host}:{args.port}/pooling"
model_name = args.model

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompt = {"model": model_name, "input": prompts}

pooling_response = post_http_request(prompt=prompt, api_url=api_url)
for output in pooling_response.json()["data"]:
multi_vector = torch.tensor(output["data"])
print(multi_vector.shape)


if __name__ == "__main__":
args = parse_args()
main(args)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ def embed(self,
return [req_output.outputs.embedding for req_output in req_outputs]

def encode(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.llm.encode(prompts)
req_outputs = self.llm.encode(prompts, pooling_task="encode")
return [req_output.outputs.data for req_output in req_outputs]

def reward(self, prompts: list[str]) -> list[list[float]]:
Expand Down
1 change: 1 addition & 0 deletions tests/entrypoints/pooling/llm/test_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def get_outputs(activation):
), "w_activation should be close to activation(wo_activation)."


@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg):
Expand Down
6 changes: 6 additions & 0 deletions tests/entrypoints/pooling/llm/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def llm():


@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384)


def test_pooling_params(llm: LLM):

def get_outputs(normalize):
Expand Down
1 change: 0 additions & 1 deletion tests/entrypoints/pooling/llm/test_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def test_multiple_pooling_params(llm: LLM):
assert len(PROMPTS) == len(outputs)


@pytest.mark.skip_global_cleanup
def test_right_side_truncation(llm: LLM):
# Embeddings models should truncate the end of the prompt
tokenizer = llm.get_tokenizer()
Expand Down
29 changes: 15 additions & 14 deletions tests/entrypoints/pooling/llm/test_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,24 @@ def llm():
cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):

def get_outputs(softmax):
outputs = llm.reward(prompts,
pooling_params=PoolingParams(softmax=softmax),
use_tqdm=False)
def get_outputs(activation):
outputs = llm.reward(
prompts,
pooling_params=PoolingParams(activation=activation),
use_tqdm=False)
return torch.cat([x.outputs.data for x in outputs])

default = get_outputs(softmax=None)
w_softmax = get_outputs(softmax=True)
wo_softmax = get_outputs(softmax=False)
default = get_outputs(activation=None)
w_activation = get_outputs(activation=True)
wo_activation = get_outputs(activation=False)

assert torch.allclose(default, w_softmax,
atol=1e-2), "Default should use softmax."
assert not torch.allclose(w_softmax, wo_softmax,
atol=1e-2), "wo_softmax should not use softmax."
assert torch.allclose(default, w_activation,
atol=1e-2), "Default should use activation."
assert not torch.allclose(
w_activation, wo_activation,
atol=1e-2), "wo_activation should not use activation."
assert torch.allclose(
softmax(wo_softmax), w_softmax,
atol=1e-2), "w_softmax should be close to softmax(wo_softmax)."
softmax(wo_activation), w_activation, atol=1e-2
), "w_activation should be close to activation(wo_activation)."
1 change: 0 additions & 1 deletion tests/entrypoints/pooling/llm/test_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def llm():
cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):

def get_outputs(activation):
Expand Down
23 changes: 22 additions & 1 deletion tests/entrypoints/pooling/openai/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
run_embedding_correctness_test)
from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.entrypoints.openai.protocol import EmbeddingResponse, PoolingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "intfloat/multilingual-e5-small"
Expand Down Expand Up @@ -394,3 +394,24 @@ async def get_outputs(normalize):
assert torch.allclose(
w_normal, F.normalize(wo_normal, p=2, dim=-1),
atol=1e-2), "w_normal should be close to normal(wo_normal)."


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling(server: RemoteOpenAIServer, model_name: str):
input_text = ["The chef prepared a delicious meal."]

response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": input_text,
"encoding_format": "float"
},
)

poolings = PoolingResponse.model_validate(response.json())

assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 11
assert len(poolings.data[0].data[0]) == 384
23 changes: 22 additions & 1 deletion tests/entrypoints/pooling/openai/test_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn.functional as F

from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import RerankResponse
from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse

MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16"
Expand Down Expand Up @@ -156,3 +156,24 @@ async def get_outputs(activation):
assert torch.allclose(
F.sigmoid(wo_activation), w_activation, atol=1e-2
), "w_activation should be close to activation(wo_activation)."


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling(server: RemoteOpenAIServer, model_name: str):
input_text = ["The chef prepared a delicious meal."]

response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": input_text,
"encoding_format": "float"
},
)

poolings = PoolingResponse.model_validate(response.json())

assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 11
assert len(poolings.data[0].data[0]) == 1
46 changes: 46 additions & 0 deletions tests/models/language/pooling/test_head_dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModelForSequenceClassification


@pytest.mark.parametrize(
"model",
["nie3e/sentiment-polish-gpt2-small"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_classify_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)

for head_dtype_str in ["float32", "model"]:
with vllm_runner(model,
max_model_len=512,
dtype=dtype,
hf_overrides={"head_dtype":
head_dtype_str}) as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
dtype = model_config.head_dtype
head_dtype = model_config.head_dtype

if head_dtype_str == "float32":
assert head_dtype == torch.float32
elif head_dtype_str == "model":
assert head_dtype == dtype

vllm_outputs = vllm_model.classify(example_prompts)

for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output).float()
vllm_output = torch.tensor(vllm_output).float()

assert torch.allclose(hf_output, vllm_output, atol=1e-2)
46 changes: 46 additions & 0 deletions tests/models/language/pooling/test_multi_vector_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModel

from tests.models.utils import check_embeddings_close


@pytest.mark.parametrize(
"model",
["BAAI/bge-m3"],
)
@pytest.mark.parametrize("dtype", ["half"])
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str):

with vllm_runner(
model,
runner="pooling",
max_model_len=None,
) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)

with hf_runner(
model,
auto_cls=AutoModel,
) as hf_model:
tokenizer = hf_model.tokenizer
hf_outputs = []
for prompt in example_prompts:
inputs = tokenizer([prompt], return_tensors="pt")
inputs = hf_model.wrap_device(inputs)
output = hf_model.model(**inputs)
embedding = output.last_hidden_state[0].float()
hf_outputs.append(embedding.cpu())

for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
check_embeddings_close(
embeddings_0_lst=hf_output,
embeddings_1_lst=vllm_output,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
Loading