Skip to content

Commit 6e4a5f4

Browse files
authored
Add FileAgent (#1433)
* final * retrigger ci * bump version * lint * try all extras
1 parent 7a9c717 commit 6e4a5f4

File tree

15 files changed

+579
-88
lines changed

15 files changed

+579
-88
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ jobs:
9696

9797
- name: Install dependencies
9898
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
99-
run: poetry install --no-interaction --only main --extras security
99+
run: poetry install --no-interaction --all-extras
100100

101101
- name: AutoFix Patchwork
102102
run: |
@@ -151,7 +151,7 @@ jobs:
151151

152152
- name: Install dependencies
153153
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
154-
run: poetry install --no-interaction --only main
154+
run: poetry install --no-interaction --all-extras
155155

156156
- name: PR Review
157157
run: |

patchwork/common/multiturn_strategy/agentic_strategy_v2.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ def model_post_init(self, __context: Any) -> None:
4242

4343
class AgenticStrategyV2:
4444
def __init__(
45-
self,
46-
model: str,
47-
llm_client: LlmClient,
48-
template_data: dict[str, str],
49-
system_prompt_template: str,
50-
user_prompt_template: str,
51-
agent_configs: list[AgentConfig],
52-
example_json: Union[str, dict[str, Any]] = '{"output":"output text"}',
53-
limit: Optional[int] = None,
45+
self,
46+
model: str,
47+
llm_client: LlmClient,
48+
template_data: dict[str, str],
49+
system_prompt_template: str,
50+
user_prompt_template: str,
51+
agent_configs: list[AgentConfig],
52+
example_json: Union[str, dict[str, Any]] = '{"output":"output text"}',
53+
limit: Optional[int] = None,
5454
):
5555
self.__limit = limit
5656
self.__template_data = template_data
@@ -153,7 +153,7 @@ def execute(self, limit: Optional[int] = None) -> dict:
153153
self.__summariser.run(
154154
"Please give me the result from the following summary of what the assistants have done."
155155
+ agent_summary_list,
156-
)
156+
)
157157
)
158158
self.__request_tokens += final_result.usage().request_tokens or 0
159159
self.__response_tokens += final_result.usage().response_tokens or 0

patchwork/common/tools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from patchwork.common.tools.api_tool import APIRequestTool
12
from patchwork.common.tools.bash_tool import BashTool
23
from patchwork.common.tools.code_edit_tools import CodeEditTool, FileViewTool
34
from patchwork.common.tools.grep_tool import FindTextTool, FindTool
4-
from patchwork.common.tools.api_tool import APIRequestTool
55
from patchwork.common.tools.tool import Tool
66

77
__all__ = [

patchwork/common/tools/api_tool.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,7 @@ def execute(
9191
status_code = response.status_code
9292
headers = response.headers
9393

94-
header_string = "\n".join(
95-
f"{key}: {value}" for key, value in headers.items()
96-
)
94+
header_string = "\n".join(f"{key}: {value}" for key, value in headers.items())
9795

9896
return (
9997
f"HTTP/{response.raw.version / 10:.1f} {status_code} {response.reason}\n"

patchwork/common/tools/code_edit_tools.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ def json_schema(self) -> dict:
4545
}
4646

4747
def __get_abs_path(self, path: str):
48-
wanted_path = Path(path).resolve()
48+
wanted_path = Path(path)
49+
if not Path(path).is_absolute():
50+
wanted_path = self.repo_path / path
4951
if wanted_path.is_relative_to(self.repo_path):
5052
return wanted_path
5153
else:
@@ -57,13 +59,16 @@ def execute(self, path: str, view_range: Optional[list[int]] = None) -> str:
5759
return f"Error: Path {abs_path} does not exist"
5860

5961
if abs_path.is_file():
60-
with open(abs_path, "r") as f:
61-
content = f.read()
62-
63-
if view_range:
64-
lines = content.splitlines()
65-
start, end = view_range
66-
content = "\n".join(lines[start - 1 : end])
62+
try:
63+
with open(abs_path, "r") as f:
64+
content = f.read()
65+
66+
if view_range:
67+
lines = content.splitlines()
68+
start, end = view_range
69+
content = "\n".join(lines[start - 1 : end])
70+
except Exception as e:
71+
content = "Error: " + str(e)
6772

6873
if len(content) > self.__VIEW_LIMIT:
6974
content = content[: self.__VIEW_LIMIT] + self.__TRUNCATION_TOKEN
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
from __future__ import annotations
2+
3+
import sqlite3
4+
import subprocess
5+
from pathlib import Path
6+
7+
import pandas
8+
from sqlalchemy import URL
9+
from typing_extensions import Optional
10+
11+
from patchwork.common.tools.tool import Tool
12+
13+
14+
class In2CSVTool(Tool, tool_name="in2csv_tool", auto_register=False):
    """Tool that converts tabular files (e.g. xls, xlsx) to CSV via csvkit's ``in2csv``.

    ``in2csv`` is executed with ``path`` as its working directory, and the CSV
    files that appear directly inside ``path`` afterwards are reported back.
    """

    def __init__(self, path: str):
        """Create the tool.

        Args:
            path: Directory used as the working directory for ``in2csv`` and
                scanned for newly created ``.csv`` files.
        """
        super().__init__()
        self.path = path

    @property
    def json_schema(self) -> dict:
        """Return the tool definition (name, description, input schema)."""
        return {
            "name": "in2csv_tool",
            "description": """\
Convert common tabular data formats to CSV.

optional arguments:
--reset-dimensions Ignore the sheet dimensions provided by the XLSX file.
--encoding-xls ENCODING_XLS
Specify the encoding of the input XLS file.
-y SNIFF_LIMIT, --snifflimit SNIFF_LIMIT
Limit CSV dialect sniffing to the specified number of
bytes. Specify "0" to disable sniffing entirely, or
"-1" to sniff the entire file.
-I, --no-inference Disable type inference (and --locale, --date-format,
--datetime-format, --no-leading-zeroes) when parsing
CSV input.
""",
            "input_schema": {
                "type": "object",
                "properties": {
                    "files": {
                        "type": "array",
                        "items": {"type": "string"},
                        # These are the inputs to convert, not CSV files.
                        "description": "The tabular file(s) (e.g. xls, xlsx) to convert to CSV",
                    },
                    "args": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "The args to run with",
                    },
                },
                "required": ["files"],
            },
        }

    def _csv_names(self) -> set[str]:
        """Return the names of the ``.csv`` files located directly in ``self.path``."""
        return {entry.name for entry in Path(self.path).iterdir() if entry.suffix == ".csv"}

    def execute(self, files: list[str], args: Optional[list[str]] = None) -> str:
        """Convert each of ``files`` to CSV and list the CSV files that were created.

        Args:
            files: Input files to convert, resolved relative to ``self.path``.
            args: Extra command line arguments forwarded to every ``in2csv`` call.

        Returns:
            ``"ERROR:\\n"`` followed by the stderr of the first failing
            ``in2csv`` call, otherwise a header line followed by one
            ``* <path>`` line per newly created CSV file.
        """
        args = args or []
        original_csvs = self._csv_names()

        # in2csv accepts a single FILE positional argument, so it is invoked once
        # per input file rather than with every file on one command line.
        for file in files:
            result = subprocess.run(
                ["in2csv", file, *args, "--write-sheets", "-", "--use-sheet-names"],
                capture_output=True,
                text=True,
                cwd=self.path,
            )
            if result.returncode != 0:
                return "ERROR:\n" + result.stderr

        # Sorted so the report is deterministic; iterdir() order is arbitrary.
        new_csvs = sorted(
            entry
            for entry in Path(self.path).iterdir()
            if entry.suffix == ".csv" and entry.name not in original_csvs
        )
        return "\n".join(["Files converted to CSV:", *(f"* {entry}" for entry in new_csvs)])
79+
80+
81+
class CSVSQLTool(Tool, tool_name="csvsql_tool", auto_register=False):
    """Tool that runs SQL queries against CSV files.

    CSV files are loaded with csvkit's ``csvsql`` into a SQLite database stored
    in ``tmp_path`` and the query is then executed against that database.
    """

    # Maximum number of characters of query output returned to the caller.
    _OUTPUT_LIMIT = 5000
    _TRUNCATION_TOKEN = "<TRUNCATED>"

    def __init__(self, path: str, tmp_path: str):
        """Create the tool.

        Args:
            path: Working directory for ``csvsql``; file arguments are resolved
                relative to it.
            tmp_path: Directory in which the ``tmp.db`` SQLite database is kept.
        """
        super().__init__()
        self.path = path
        self.tmp_path = tmp_path

    @property
    def json_schema(self) -> dict:
        """Return the tool definition (name, description, input schema)."""
        return {
            "name": "csvsql_tool",
            "description": """\
Execute SQL query directly on csv files. The name of the csv files can be referenced as table in the SQL query

If the output is larger than 5000 characters, the remaining characters are replaced with <TRUNCATED>.
""",
            "input_schema": {
                "type": "object",
                "properties": {
                    "files": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "The CSV file(s) to operate on",
                    },
                    "query": {
                        "type": "string",
                        "description": "SQL query to execute",
                    },
                },
                "required": ["files", "query"],
            },
        }

    @staticmethod
    def _table_name(file: str) -> str:
        """Return the table name expected for ``file``: its base name without extension."""
        # NOTE(review): assumes csvsql names tables after the file's base name
        # without extension; confirm against the csvsql documentation.
        return Path(file).stem

    def execute(self, files: list[str], query: str) -> str:
        """Load ``files`` into the SQLite database if needed, then run ``query``.

        Args:
            files: CSV files that the query refers to, relative to ``self.path``.
            query: SQL query to execute against the imported tables.

        Returns:
            The query result rendered as CSV, truncated to 5000 characters
            followed by ``<TRUNCATED>``, or a string starting with ``"ERROR:\\n"``
            when importing the files or running the query fails.
        """
        from contextlib import closing

        db_path = (Path(self.tmp_path) / "tmp.db").resolve()
        db_url = URL.create(drivername="sqlite", host="/" + str(db_path)).render_as_string()

        files_to_insert = files
        if db_path.is_file():
            # Ask SQLite which tables exist instead of probing each table with a
            # string-built SELECT: probing raises for a table that has not been
            # imported yet and interpolates untrusted names into the statement.
            with closing(sqlite3.connect(str(db_path))) as conn:
                rows = conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'").fetchall()
            existing_tables = {row[0] for row in rows}
            files_to_insert = [file for file in files if self._table_name(file) not in existing_tables]

        if len(files_to_insert) > 0:
            result = subprocess.run(
                ["csvsql", *files_to_insert, "--db", db_url, "--insert"],
                capture_output=True,
                text=True,
                cwd=self.path,
            )
            if result.returncode != 0:
                return "ERROR:\n" + result.stderr

        # closing() releases the connection; the inner `conn` context keeps the
        # commit/rollback behaviour, since it does not close the connection itself.
        with closing(sqlite3.connect(str(db_path))) as conn, conn:
            try:
                pandas_df = pandas.read_sql_query(query, conn)
            except (sqlite3.Error, pandas.errors.DatabaseError) as e:
                # Report invalid queries to the caller the same way csvsql
                # failures are reported, instead of raising out of the tool.
                return "ERROR:\n" + str(e)
        rv = pandas_df.to_csv()

        if len(rv) > self._OUTPUT_LIMIT:
            return rv[: self._OUTPUT_LIMIT] + self._TRUNCATION_TOKEN
        return rv

patchwork/common/tools/grep_tool.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,10 @@ def json_schema(self) -> dict:
158158
}
159159

160160
def execute(
161-
self,
162-
pattern: Optional[str] = None,
163-
path: Optional[Path] = None,
164-
is_case_sensitive: bool = False,
161+
self,
162+
pattern: Optional[str] = None,
163+
path: Optional[Path] = None,
164+
is_case_sensitive: bool = False,
165165
) -> str:
166166
if pattern is None:
167167
raise ValueError("pattern argument is required!")
@@ -183,18 +183,22 @@ def execute(
183183
paths = [p for p in path.iterdir() if p.is_file()]
184184

185185
from collections import defaultdict
186+
186187
file_matches = defaultdict(list)
187188
for path in paths:
188-
with path.open("r") as f:
189-
for i, line in enumerate(f.readlines()):
190-
if not matcher(line, pattern):
191-
continue
189+
try:
190+
with path.open("r") as f:
191+
for i, line in enumerate(f.readlines()):
192+
if not matcher(line, pattern):
193+
continue
192194

193-
content = f"Line {i + 1}: {line}"
194-
if len(line) > self.__CHAR_LIMIT:
195-
content = f"Line {i + 1}: {self.__CHAR_LIMIT_TEXT}"
195+
content = f"Line {i + 1}: {line}"
196+
if len(line) > self.__CHAR_LIMIT:
197+
content = f"Line {i + 1}: {self.__CHAR_LIMIT_TEXT}"
196198

197-
file_matches[str(path)].append(content)
199+
file_matches[str(path)].append(content)
200+
except Exception as e:
201+
pass
198202

199203
total_file_matches = ""
200204
for path_str, matches in file_matches.items():
@@ -207,4 +211,3 @@ def execute(
207211
for path_str, matches in file_matches.items():
208212
total_file_matches += f"\n {len(matches)} Pattern matches found in '{path}': <TRUNCATED>\n"
209213
return total_file_matches
210-

patchwork/steps/ExtractPackageManagerFile/TestExtractPackageManagerFile.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import unittest
33
from pathlib import Path
44

5-
from patchwork.steps.ExtractPackageManagerFile.ExtractPackageManagerFile import ExtractPackageManagerFile
5+
from patchwork.steps.ExtractPackageManagerFile.ExtractPackageManagerFile import (
6+
ExtractPackageManagerFile,
7+
)
68

79

810
class TestExtractPackageManagerFile(unittest.TestCase):
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import tempfile
2+
from pathlib import Path
3+
4+
from patchwork.common.client.llm.aio import AioLlmClient
5+
from patchwork.common.multiturn_strategy.agentic_strategy_v2 import (
6+
AgentConfig,
7+
AgenticStrategyV2,
8+
)
9+
from patchwork.common.tools import FileViewTool, FindTextTool
10+
from patchwork.common.tools.csvkit_tool import CSVSQLTool, In2CSVTool
11+
from patchwork.common.utils.utils import mustache_render
12+
from patchwork.step import Step
13+
from patchwork.steps.FileAgent.typed import FileAgentInputs, FileAgentOutputs
14+
15+
16+
class FileAgent(Step, input_class=FileAgentInputs, output_class=FileAgentOutputs):
    """Step that lets an LLM agent inspect and query files under a base directory.

    The agent gets text search, file viewing, tabular-to-CSV conversion and
    SQL-over-CSV tools, all rooted at ``base_path``.
    """

    def __init__(self, inputs):
        """Prepare the keyword arguments used to build the agentic strategy.

        Reads ``base_path`` (defaults to the current working directory),
        ``prompt_value`` (data used to render ``task``), ``task`` (required) and
        the optional ``example_json`` from ``inputs``.
        """
        super().__init__(inputs)
        self.base_path = inputs.get("base_path", str(Path.cwd()))
        data = inputs.get("prompt_value", {})
        task = mustache_render(inputs["task"], data)

        self.strat_kwargs = dict(
            model="claude-3-5-sonnet-latest",
            llm_client=AioLlmClient.create_aio_client(inputs),
            template_data=dict(),
            system_prompt_template="""\
Please summarise the conversation given and provide the result in the structure that is asked of you.
""",
            user_prompt_template=f"""\
Please help me with this task:

{task}
""",
            agent_configs=[
                AgentConfig(
                    name="Assistant",
                    model="claude-3-7-sonnet-latest",
                    # The tools are attached in run(), once the temporary directory exists.
                    tool_set=dict(),
                    system_prompt="""\
You are a assistant that is supposed to help me with a set of files. These files are commonly tabular formatted like csv, xls or xlsx.
If you find a tabular formatted file you should use the `in2csv_tool` tool to convert the files into CSV.

After that is done, then run other tools to assist me.
""",
                )
            ],
        )

        # Forward example_json only when it was supplied: passing an explicit None
        # would override the default value declared by AgenticStrategyV2.
        example_json = inputs.get("example_json")
        if example_json is not None:
            self.strat_kwargs["example_json"] = example_json

    def run(self) -> dict:
        """Run the agent and return its result merged with the usage figures.

        A temporary directory is created for the duration of the run and handed
        to the SQL tool as the location of its scratch database.
        """
        kwargs = self.strat_kwargs
        with tempfile.TemporaryDirectory() as tmpdir:
            # __init__ always configures exactly one agent.
            agent_config = kwargs["agent_configs"][0]
            agent_config.tool_set = dict(
                find_text=FindTextTool(self.base_path),
                file_view=FileViewTool(self.base_path),
                in2csv_tool=In2CSVTool(self.base_path),
                csvsql_tool=CSVSQLTool(self.base_path, tmpdir),
            )
            agentic_strategy = AgenticStrategyV2(**kwargs)
            result = agentic_strategy.execute(limit=10)
            return {**result, **agentic_strategy.usage()}

patchwork/steps/FileAgent/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)