Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit f79cbac

Browse files
committed
Code snippet extraction pipeline step
Adds a new pipeline step that extracts code snippets from the user message and saves them into the pipeline context. For now the snippets are only stored in the context; nothing consumes them yet. Fixes: #46
1 parent b58e5aa commit f79cbac

File tree

5 files changed

+387
-7
lines changed

5 files changed

+387
-7
lines changed

src/codegate/pipeline/base.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,13 @@ class CodeSnippet:
1515
code: The actual code content
1616
"""
1717

18-
language: str
18+
language: Optional[str]
19+
filepath: Optional[str]
1920
code: str
2021

2122
def __post_init__(self):
22-
if not self.language or not self.language.strip():
23-
raise ValueError("Language must not be empty")
24-
if not self.code or not self.code.strip():
25-
raise ValueError("Code must not be empty")
26-
self.language = self.language.strip().lower()
23+
if self.language is not None:
24+
self.language = self.language.strip().lower()
2725

2826

2927
@dataclass
@@ -57,6 +55,7 @@ class PipelineResult:
5755

5856
request: Optional[ChatCompletionRequest] = None
5957
response: Optional[PipelineResponse] = None
58+
context: Optional[PipelineContext] = None
6059
error_message: Optional[str] = None
6160

6261
def shortcuts_processing(self) -> bool:
@@ -165,4 +164,7 @@ async def process_request(
165164
if result.request is not None:
166165
current_request = result.request
167166

168-
return PipelineResult(request=current_request)
167+
if result.context is not None:
168+
context = result.context
169+
170+
return PipelineResult(request=current_request, context=context)

src/codegate/pipeline/extract_snippets/__init__.py

Whitespace-only changes.
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import os
2+
import re
3+
from typing import List, Optional
4+
5+
import structlog
6+
from litellm.types.llms.openai import ChatCompletionRequest
7+
8+
from codegate.pipeline.base import CodeSnippet, PipelineContext, PipelineResult, PipelineStep
9+
10+
# Matches a fenced code block whose opening fence may carry metadata, e.g.
#   ```py /path/to/file.py (10-20)
#   ```/path/to/file.py
#   ```
# The whole header must sit on the fence line itself: only horizontal
# whitespace ([^\S\n]) is allowed between its parts, so the first line of the
# code can never be mistaken for a filename or for line info.
CODE_BLOCK_PATTERN = re.compile(
    r"```"
    # Optional language tag; only recognised when followed by same-line whitespace.
    r"(?:(?P<language>\w+)[^\S\n]+)?"
    # Optional filename: no whitespace and no opening parenthesis.
    r"(?P<filename>[^\s\(]+)?"
    # Optional parenthesised line info, e.g. "(10-20)", on the fence line.
    r"(?:[^\S\n]*\((?P<lineinfo>[^)\n]+)\))?"
    # The header ends at the first newline.
    r"\n"
    # Everything up to the next fence, matched lazily.
    r"(?P<content>(?:.|\n)*?)"
    r"```"
)
13+
14+
# Module-level structlog logger, obtained under the name "codegate".
logger = structlog.get_logger("codegate")
15+
16+
def ecosystem_from_filepath(filepath: str) -> Optional[str]:
    """
    Map a file path to a language name using its extension.

    Args:
        filepath: Path (or bare file name) to inspect

    Returns:
        The language associated with the extension, compared
        case-insensitively, or None when the extension is missing or unknown
    """
    # Only the final extension is considered; it is lowercased before lookup.
    _, suffix = os.path.splitext(filepath)

    known_suffixes = {
        ".py": "python",
        ".js": "javascript",
        ".ts": "typescript",
        ".tsx": "typescript",
        ".go": "go",
        ".rs": "rust",
        ".java": "java",
    }
    return known_suffixes.get(suffix.lower())
40+
41+
42+
def ecosystem_from_message(message: str) -> Optional[str]:
    """
    Determine language from the language tag found in a message.

    Args:
        message: The language tag taken from the message. Some extensions send
                 a different format where the language is present in the
                 snippet, e.g. "py /path/to/file (lineFrom-lineTo)"; here only
                 the tag itself ("py") is expected.

    Returns:
        The full language name for the tag, or None when the tag is empty or
        not recognised. Matching ignores case and surrounding whitespace, and a
        tag that already is a full language name (e.g. "python") is returned
        as that name.
    """
    # Short tags cover the same languages that ecosystem_from_filepath
    # recognises by file extension.
    language_mapping = {
        "py": "python",
        "js": "javascript",
        "ts": "typescript",
        "tsx": "typescript",
        "go": "go",
        "rs": "rust",
        "java": "java",
    }

    if not message:
        return None

    tag = message.strip().lower()
    # Accept full names as well as the abbreviated tags.
    if tag in language_mapping.values():
        return tag
    return language_mapping.get(tag)
62+
63+
64+
def extract_snippets(message: str) -> List[CodeSnippet]:
    """
    Collect every fenced code block in a message as a CodeSnippet.

    Args:
        message: Input text that may contain fenced code blocks

    Returns:
        One CodeSnippet per match of CODE_BLOCK_PATTERN, in order of
        appearance; an empty list when nothing matches
    """
    found: List[CodeSnippet] = []

    for block in CODE_BLOCK_PATTERN.finditer(message):
        path = block.group("filename")
        tag = block.group("language")

        # The explicit language tag takes precedence; the file extension is
        # consulted only when the tag is absent or unrecognised.
        language = ecosystem_from_message(tag.strip()) if tag else None
        if language is None and path:
            path = path.strip()
            language = ecosystem_from_filepath(path)

        found.append(
            CodeSnippet(
                filepath=path,
                code=block.group("content"),
                language=language,
            )
        )

    return found
96+
97+
98+
class CodeSnippetExtractor(PipelineStep):
    """
    Pipeline step that pulls code snippets out of the last user message and
    records them on the pipeline context. It does not modify the request.
    """

    def __init__(self):
        """Create the step; all setup is delegated to the base class."""
        super().__init__()

    @property
    def name(self) -> str:
        """Identifier of this pipeline step."""
        return "code-snippet-extractor"

    async def process(
        self,
        request: ChatCompletionRequest,
        context: PipelineContext,
    ) -> PipelineResult:
        """
        Extract snippets from the last user message and add them to the context.

        Args:
            request: The incoming chat completion request
            context: Pipeline context that receives the extracted snippets

        Returns:
            A PipelineResult carrying the context. The request is passed back
            only on the early-return path where no user message was found.
        """
        user_message = self.get_last_user_message(request)
        if not user_message:
            # Nothing to scan; hand request and context back untouched.
            return PipelineResult(request=request, context=context)

        # Only the first element (the message text) is used here.
        text, _ = user_message
        snippets = extract_snippets(text)

        logger.info(f"Extracted {len(snippets)} code snippets from the user message")

        for snippet in snippets:
            logger.debug(f"Code snippet: {snippet}")
            context.add_code_snippet(snippet)

        return PipelineResult(context=context)

src/codegate/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from codegate.config import Config
77
from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor
88
from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt
9+
from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippetExtractor
910
from codegate.pipeline.version.version import CodegateVersion
1011
from codegate.providers.anthropic.provider import AnthropicProvider
1112
from codegate.providers.llamacpp.provider import LlamaCppProvider
@@ -23,6 +24,7 @@ def init_app() -> FastAPI:
2324

2425
steps: List[PipelineStep] = [
2526
CodegateVersion(),
27+
CodeSnippetExtractor(),
2628
CodegateSystemPrompt(Config.get_config().prompts.codegate_chat),
2729
# CodegateSecrets(),
2830
]

0 commit comments

Comments
 (0)