Skip to content

Commit 25fda79

Browse files
refactor parsers
1 parent b6024b7 commit 25fda79

File tree

9 files changed

+267
-255
lines changed

9 files changed

+267
-255
lines changed

rllm/parser/__init__.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1-
from rllm.parser.tool_parser.qwen_tool_parser import QwenToolParser
2-
from rllm.parser.tool_parser.r1_tool_parser import R1ToolParser
3-
from rllm.parser.tool_parser.tool_parser_base import ToolParser
1+
from rllm.parser.chat_template_parser import ChatTemplateParser, DeepseekQwenChatTemplateParser, LlamaChatTemplateParser, QwenChatTemplateParser
2+
from rllm.parser.tool_parser import QwenToolParser, R1ToolParser, ToolParser
3+
4+
# Public API of the parser package. ``get_tool_parser`` and ``PARSER_REGISTRY``
# are defined further down in this module and were exported before the
# refactor, so they stay listed to keep ``from rllm.parser import *`` working.
__all__ = [
    "ChatTemplateParser",
    "DeepseekQwenChatTemplateParser",
    "QwenChatTemplateParser",
    "LlamaChatTemplateParser",
    "ToolParser",
    "R1ToolParser",
    "QwenToolParser",
    "get_tool_parser",
    "PARSER_REGISTRY",
]
413

514
PARSER_REGISTRY = {
615
"r1": R1ToolParser,
@@ -11,12 +20,3 @@
1120
def get_tool_parser(parser_name: str) -> type[ToolParser]:
    """Look up a tool parser class by its registry name.

    Args:
        parser_name: Key into ``PARSER_REGISTRY``.

    Returns:
        The parser class registered under ``parser_name`` (not an instance).

    Raises:
        ValueError: If ``parser_name`` is not registered. An explicit raise is
            used instead of ``assert`` so the check survives ``python -O``.
    """
    if parser_name not in PARSER_REGISTRY:
        raise ValueError(f"Tool parser {parser_name} not found in {PARSER_REGISTRY}")
    return PARSER_REGISTRY[parser_name]
14-
15-
16-
__all__ = [
17-
"R1ToolParser",
18-
"QwenToolParser",
19-
"ToolParser",
20-
"get_tool_parser",
21-
"PARSER_REGISTRY",
22-
]

rllm/parser/chat_template/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

rllm/parser/tool_parser.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import json
2+
from abc import ABC, abstractmethod
3+
from typing import Any
4+
5+
from rllm.tools.tool_base import ToolCall
6+
7+
8+
class ToolParser(ABC):
    """Abstract interface for model-specific tool call parsers.

    A concrete parser knows how to extract tool calls from generated text and
    how to render the prompt section that describes the available tools.
    """

    @abstractmethod
    def parse(self, model_response: str) -> list[ToolCall]:
        """Extract tool calls from the model response.

        Args:
            model_response: Text generated by the model.

        Returns:
            The tool calls found in ``model_response``.
        """
        raise NotImplementedError("Subclasses must implement this method")

    @abstractmethod
    def get_tool_prompt(self, tools_schema: str) -> str:
        """Get the tool prompt for the model.

        Args:
            tools_schema: Serialized description of the available tools.

        Returns:
            Prompt text that embeds ``tools_schema``.
        """
        raise NotImplementedError("Subclasses must implement this method")

    @classmethod
    def get_parser(cls, tokenizer) -> "ToolParser":
        """Factory method to get the appropriate tool parser for a tokenizer.

        The choice is made from the lower-cased ``tokenizer.name_or_path`` and
        the lower-cased tokenizer class name.

        Args:
            tokenizer: The tokenizer whose model family selects the parser.

        Returns:
            ToolParser: A new instance of the matching parser.

        Raises:
            ValueError: If no parser matches, including when the tokenizer has
                no string ``name_or_path``.
        """
        # getattr keeps a tokenizer without ``name_or_path`` on the documented
        # ValueError path instead of failing with AttributeError.
        name_or_path = getattr(tokenizer, "name_or_path", None)

        # Determine parser type based on tokenizer name or path
        if isinstance(name_or_path, str):
            model_name = name_or_path.lower()
            tokenizer_cls = tokenizer.__class__.__name__.lower()
            print(f"model_name: {model_name}, tokenizer_cls: {tokenizer_cls}")
            if any(x in model_name for x in ("deepseek", "deepscaler", "deepcoder")) and "llama" in tokenizer_cls:
                print(f"Using R1ToolParser for {name_or_path}")
                return R1ToolParser()
            elif "qwen" in model_name or "r2e" in model_name or "deepswe" in model_name or "qwen" in tokenizer_cls:
                print(f"Using QwenToolParser for {name_or_path}")
                return QwenToolParser()

        # TODO: add verification to check equivalence of the parser with that from HuggingFace
        raise ValueError(f"No tool parser found for {name_or_path}")
45+
46+
47+
class R1ToolParser(ToolParser):
    """Parser for R1 tool call format."""

    def __init__(self):
        """Initialize the special tokens that delimit R1 tool calls."""
        self.tool_calls_begin = "<|tool▁calls▁begin|>"
        self.tool_calls_end = "<|tool▁calls▁end|>"
        self.tool_call_begin = "<|tool▁call▁begin|>"
        self.tool_call_end = "<|tool▁call▁end|>"
        self.tool_sep = "<|tool▁sep|>"

    def parse(self, model_response: str) -> list[ToolCall]:
        """Parse tool calls from model output.

        Args:
            model_response (str): Text containing tool calls

        Returns:
            list[ToolCall]: Parsed tool calls, in order of appearance
        """
        tool_calls_dicts = self.parse_r1_tool_calls(model_response)

        # Convert dictionaries to ToolCall objects
        return [ToolCall(name=tc["name"], arguments=tc["arguments"]) for tc in tool_calls_dicts]

    def parse_r1_tool_calls(self, text: str) -> list[dict]:
        """Parse tool calls from text using the R1 special token format.

        Format:
        <|tool▁calls▁begin|>
        <|tool▁call▁begin|>function<|tool▁sep|>function_name
        ```json
        {"param": "value"}
        ```
        <|tool▁call▁end|>
        // Additional tool calls follow the same format
        <|tool▁calls▁end|>

        Malformed calls (missing function prefix, missing JSON fence, or
        invalid JSON) are skipped. Scanning stops at the first call that has
        no closing token.

        Returns:
            list[dict]: List of parsed tool calls, each containing 'name' and 'arguments'
        """
        tool_calls = []
        cursor = 0
        while True:
            # Find the next tool call beginning
            begin = text.find(self.tool_call_begin, cursor)
            if begin == -1:
                break

            # Find the end of this tool call
            content_start = begin + len(self.tool_call_begin)
            content_end = text.find(self.tool_call_end, content_start)
            if content_end == -1:
                break

            # Always move past this call, whether or not it parses
            cursor = content_end + len(self.tool_call_end)

            parsed = self._parse_single_call(text[content_start:content_end].strip())
            if parsed is not None:
                tool_calls.append(parsed)

        return tool_calls

    def _parse_single_call(self, call_content: str) -> "dict[str, Any] | None":
        """Parse the text between one call-begin/call-end pair; None if malformed."""
        # Parse function name: text after the prefix up to the next newline
        func_prefix = "function" + self.tool_sep
        prefix_at = call_content.find(func_prefix)
        if prefix_at == -1:
            return None
        name_start = prefix_at + len(func_prefix)
        name_end = call_content.find("\n", name_start)
        if name_end == -1:
            name_end = len(call_content)
        function_name = call_content[name_start:name_end].strip()

        # Extract JSON arguments; prefer a fence followed by a newline, then
        # fall back to a bare fence
        fence = "```json"
        json_start = call_content.find(fence + "\n")
        if json_start != -1:
            json_start += len(fence) + 1
        else:
            json_start = call_content.find(fence)
            if json_start == -1:
                return None
            json_start += len(fence)

        json_end = call_content.find("```", json_start)
        if json_end == -1:
            return None

        try:
            arguments = json.loads(call_content[json_start:json_end].strip())
        except json.JSONDecodeError:
            return None

        return {"name": function_name, "arguments": arguments}

    def get_tool_prompt(self, tools_schema: str) -> str:
        """Build the tool-use prompt section for R1-style models.

        The prompt spells out the complete call format that
        ``parse_r1_tool_calls`` accepts: the function line, the fenced JSON
        arguments, the call-end token and the final calls-end token.

        Args:
            tools_schema: Serialized description of the available tools.

        Returns:
            Prompt text that embeds ``tools_schema``.
        """
        return f"""
# Tools

You may call one or more functions to assist with the user query.
<tools>
{tools_schema}
</tools>

For function call returns, you should first print {self.tool_calls_begin}

For each function call, you should return object like:
{self.tool_call_begin}function{self.tool_sep}<function_name>
```json
<function_arguments_in_json_format>
```{self.tool_call_end}

At the end of function call returns, you should print {self.tool_calls_end}
"""
177+
178+
179+
class QwenToolParser(ToolParser):
    """Parser for the Qwen ``<tool_call>{json}</tool_call>`` format."""

    def __init__(self):
        """Initialize the tags that delimit Qwen tool calls and tool outputs."""
        self.tool_call_begin = "<tool_call>"
        self.tool_call_end = "</tool_call>"
        self.tool_output_begin = "<tool_response>"
        self.tool_output_end = "</tool_response>"

    def parse(self, model_response: str) -> list[ToolCall]:
        """Parse tool calls from model output.

        Args:
            model_response (str): Text containing tool calls

        Returns:
            list[ToolCall]: Parsed tool calls, in order of appearance
        """
        tool_calls_dicts = self.parse_qwen_tool_calls(model_response)
        return [ToolCall(name=tc["name"], arguments=tc["arguments"]) for tc in tool_calls_dicts]

    def parse_qwen_tool_calls(self, text: str) -> list[dict[str, Any]]:
        """Parse tool calls from text using a simple token format.

        Format:
        <tool_call>{"name": "function_name", "arguments": {...}}</tool_call>

        A call whose payload is not valid JSON, is not a JSON object, or lacks
        the 'name' or 'arguments' key is reported and skipped. Scanning stops
        at the first opening tag that has no closing tag after it.

        Returns:
            list[dict]: List of parsed tool calls, each containing 'name' and 'arguments'
        """
        tool_calls: list[dict[str, Any]] = []
        cursor = 0

        # Process all tool calls in the text
        while True:
            begin = text.find(self.tool_call_begin, cursor)
            if begin == -1:
                break
            start = begin + len(self.tool_call_begin)
            # Search only after the opening tag so a stray closing tag earlier
            # in the text cannot be paired with it
            end = text.find(self.tool_call_end, start)
            if end == -1:
                break

            # Move to next potential tool call
            cursor = end + len(self.tool_call_end)

            # Extract and parse the JSON content
            json_content = text[start:end].strip()
            try:
                call_data = json.loads(json_content)
                # Convert to common format matching parse_tool_calls output
                tool_calls.append({"name": call_data["name"], "arguments": call_data["arguments"]})
            except (json.JSONDecodeError, KeyError, TypeError):
                # KeyError: required key missing; TypeError: payload is not an object
                print(f"Error parsing tool call: {json_content}")

        return tool_calls

    def get_tool_prompt(self, tools_schema: str) -> str:
        """Build the tool-use prompt section for Qwen-style models.

        Args:
            tools_schema: Serialized description of the available tools.

        Returns:
            Prompt text that embeds ``tools_schema`` and shows the expected
            ``<tool_call>`` format.
        """
        return f"""
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tools_schema}
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call><|im_end|>
"""

rllm/parser/tool_parser/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

rllm/parser/tool_parser/qwen_tool_parser.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

0 commit comments

Comments
 (0)