diff --git a/AI_Agents_Guide/Constrained_Decoding/README.md b/AI_Agents_Guide/Constrained_Decoding/README.md
index 82e2e00e..6ce2c5a5 100644
--- a/AI_Agents_Guide/Constrained_Decoding/README.md
+++ b/AI_Agents_Guide/Constrained_Decoding/README.md
@@ -264,15 +264,17 @@
 to serve your TensorRT-LLM model.), custom logits processor should be specified
 during model's initialization as a part of
 [Executor's](https://nvidia.github.io/TensorRT-LLM/executor.html#executor-api)
 configuration
-([`logits_post_processor_map`](https://github.com/NVIDIA/TensorRT-LLM/blob/32ed92e4491baf2d54682a21d247e1948cca996e/tensorrt_llm/hlapi/llm_utils.py#L205)).
+([`logits_post_processor_map`](https://github.com/NVIDIA/TensorRT-LLM/blob/258c7540c03517def55d9a5aadfa9288af474e1b/tensorrt_llm/llmapi/llm_utils.py#L322)).
 Below is the sample for reference.
 
 ```diff
 ...
-+ executor_config.logits_post_processor_map = {
-+     "": custom_logits_processor
-+ }
++ logits_proc_config = trtllm.LogitsPostProcessorConfig()
++ logits_proc_config.processor_map = {
++     "": custom_logits_processor
++ }
++ executor_config.logits_post_processor_config = logits_proc_config
 self.executor = trtllm.Executor(model_path=...,
                                 model_type=...,
                                 executor_config=executor_config)
@@ -331,17 +333,15 @@
 def execute(self, requests):
   ...
   for request in requests:
-    response_sender = request.get_response_sender()
-    if get_input_scalar_by_name(request, 'stop'):
-      self.handle_stop_request(request.request_id(), response_sender)
-    else:
+    ...
     try:
-      converted = convert_request(request,
-                                  self.exclude_input_from_output,
-                                  self.decoupled)
+      converted_reqs = convert_request(
+          request, self.exclude_input_from_output,
+          self.decoupled)
 +      logits_post_processor_name = get_input_tensor_by_name(request, 'logits_post_processor_name')
 +      if logits_post_processor_name is not None:
-+        converted.logits_post_processor_name = logits_post_processor_name.item().decode('utf-8')
++        for converted in converted_reqs:
++          converted.logits_post_processor_name = logits_post_processor_name.item().decode('utf-8')
     except Exception as e:
       ...
 ```
@@ -470,6 +470,10 @@ class TritonPythonModel:
     def get_executor_config(self, model_config):
 +       tokenizer_dir = model_config['parameters']['tokenizer_dir']['string_value']
 +       logits_processor = LMFELogitsProcessor(tokenizer_dir, AnswerFormat.model_json_schema())
++       logits_proc_config = trtllm.LogitsPostProcessorConfig()
++       logits_proc_config.processor_map = {
++           LMFELogitsProcessor.PROCESSOR_NAME: logits_processor
++       }
         kwargs = {
             "max_beam_width": get_parameter(model_config, "max_beam_width", int),
@@ -490,9 +494,7 @@ class TritonPythonModel:
             self.get_peft_cache_config(model_config),
             "decoding_config":
             self.get_decoding_config(model_config),
-+           "logits_post_processor_map":{
-+               LMFELogitsProcessor.PROCESSOR_NAME: logits_processor
-+           }
++           "logits_post_processor_config": logits_proc_config
         }
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         return trtllm.ExecutorConfig(**kwargs)
@@ -603,6 +605,10 @@ class TritonPythonModel:
     def get_executor_config(self, model_config):
 +       tokenizer_dir = model_config['parameters']['tokenizer_dir']['string_value']
 +       logits_processor = OutlinesLogitsProcessor(tokenizer_dir, AnswerFormat.model_json_schema())
++       logits_proc_config = trtllm.LogitsPostProcessorConfig()
++       logits_proc_config.processor_map = {
++           OutlinesLogitsProcessor.PROCESSOR_NAME: logits_processor
++       }
         kwargs = {
             "max_beam_width": get_parameter(model_config, "max_beam_width", int),
@@ -623,9 +629,7 @@ class TritonPythonModel:
             self.get_peft_cache_config(model_config),
             "decoding_config":
             self.get_decoding_config(model_config),
-+           "logits_post_processor_map":{
-+               OutlinesLogitsProcessor.PROCESSOR_NAME: logits_processor
-+           }
++           "logits_post_processor_config": logits_proc_config
         }
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         return trtllm.ExecutorConfig(**kwargs)
diff --git a/AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py b/AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py
index 70ef4237..cb21eab6 100644
--- a/AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py
+++ b/AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py
@@ -25,15 +25,14 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import json
-from collections import defaultdict
-from typing import DefaultDict, Dict, List
+from typing import Any, Dict, List, Optional
 
 import torch
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer
 from lmformatenforcer.integrations.trtllm import build_trtlmm_tokenizer_data
-from outlines.fsm.guide import RegexGuide
-from outlines.fsm.json_schema import build_regex_from_schema
-from outlines.integrations.utils import adapt_tokenizer
+from outlines.fsm.guide import Guide, RegexGuide
+from outlines.models.vllm import adapt_tokenizer
+from outlines_core.fsm.json_schema import build_regex_from_schema
 from pydantic import BaseModel
 from transformers import AutoTokenizer
@@ -103,6 +102,7 @@ def __call__(
         logits: torch.Tensor,
         ids: List[List[int]],
         stream_ptr: int,
+        client_id: Optional[int]
     ):
         # Create a mask with negative infinity to block all tokens initially.
         mask = torch.full_like(logits, fill_value=float("-inf"), device=logits.device)
@@ -127,8 +127,8 @@ def __init__(self, tokenizer_dir, schema):
         )
         tokenizer = adapt_tokenizer(tokenizer)
         regex_string = build_regex_from_schema(json.dumps(schema))
-        self.fsm = RegexGuide(regex_string, tokenizer)
-        self._fsm_state: DefaultDict[int, int] = defaultdict(int)
+        self.guide: Guide = RegexGuide.from_regex(regex_string, tokenizer)
+        self._guide_states: Dict[int, Any] = {}
         self.mask_cache: Dict[int, torch.Tensor] = {}
         # By default, TensorRT-LLM includes request query into the output.
         # Outlines should only look at generated outputs, thus we'll keep
@@ -141,6 +141,7 @@ def __call__(
         logits: torch.Tensor,
         ids: List[List[int]],
         stream_ptr: int,
+        client_id: Optional[int]
     ):
         seq_id = None
         # If the prefix token IDs have changed we assume that we are dealing
@@ -151,9 +152,9 @@ def __call__(
             # processed
             or len(ids[0][len(self._prefix) :]) == 0
         ):
-            self._fsm_state = defaultdict(int)
-            self._prefix = ids[0]
             seq_id = hash(tuple([]))
+            self._guide_states = {seq_id: self.guide.initial_state}
+            self._prefix = ids[0]
         else:
             # Remove the prefix token IDs from the input token IDs,
@@ -162,14 +163,14 @@ def __call__(
             last_token = ids[-1]
             last_seq_id = hash(tuple(ids[:-1]))
             seq_id = hash(tuple(ids))
-            self._fsm_state[seq_id] = self.fsm.get_next_state(
-                state=self._fsm_state[last_seq_id], token_id=last_token
+            self._guide_states[seq_id] = self.guide.get_next_state(
+                state=self._guide_states[last_seq_id], token_id=last_token
             )
-        state_id = self._fsm_state[seq_id]
+        state_id = self._guide_states[seq_id]
         if state_id not in self.mask_cache:
-            allowed_tokens = self.fsm.get_next_instruction(
-                state=self._fsm_state[seq_id]
+            allowed_tokens = self.guide.get_next_instruction(
+                state=self._guide_states[seq_id]
             ).tokens
             # Create a mask with negative infinity to block all
             # tokens initially.