Skip to content

Commit d25b6f9

Browse files
committed
Fix sampling params
1 parent c8ee100 commit d25b6f9

File tree

4 files changed

+7
-37
lines changed

4 files changed

+7
-37
lines changed

src/engine.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from constants import DEFAULT_MAX_CONCURRENCY, DEFAULT_BATCH_SIZE, DEFAULT_BATCH_SIZE_GROWTH_FACTOR, DEFAULT_MIN_BATCH_SIZE
1616
from tokenizer import TokenizerWrapper
1717
from config import EngineConfig
18-
from sampling_params import validate_sampling_params
1918

2019
class vLLMEngine:
2120
def __init__(self, engine = None):
@@ -33,10 +32,9 @@ def dynamic_batch_size(self, current_batch_size, batch_size_growth_factor):
3332

3433
async def generate(self, job_input: JobInput):
3534
try:
36-
validated_sampling_params = validate_sampling_params(job_input.input_sampling_params)
3735
async for batch in self._generate_vllm(
3836
llm_input=job_input.llm_input,
39-
validated_sampling_params=validated_sampling_params,
37+
validated_sampling_params=job_input.sampling_params,
4038
batch_size=job_input.max_batch_size,
4139
stream=job_input.stream,
4240
apply_chat_template=job_input.apply_chat_template,

src/sampling_params.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

src/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Dict
44
from vllm.utils import random_uuid
55
from vllm.entrypoints.openai.protocol import ErrorResponse
6+
from vllm import SamplingParams
67

78
logging.basicConfig(level=logging.INFO)
89

@@ -31,7 +32,7 @@ def __init__(self, job):
3132
self.max_batch_size = job.get("max_batch_size")
3233
self.apply_chat_template = job.get("apply_chat_template", False)
3334
self.use_openai_format = job.get("use_openai_format", False)
34-
self.input_sampling_params = job.get("sampling_params", {})
35+
self.sampling_params = SamplingParams(**job.get("sampling_params", {}))
3536
self.request_id = random_uuid()
3637
batch_size_growth_factor = job.get("batch_size_growth_factor")
3738
self.batch_size_growth_factor = float(batch_size_growth_factor) if batch_size_growth_factor else None
@@ -62,4 +63,6 @@ def update(self):
6263
def create_error_response(message: str, err_type: str = "BadRequestError", status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
6364
return ErrorResponse(message=message,
6465
type=err_type,
65-
code=status_code.value)
66+
code=status_code.value)
67+
68+

vllm-base-image/vllm

Submodule vllm updated from 9cb4acd to c46d230

0 commit comments

Comments (0)