diff --git a/config.yaml b/config.yaml index a0c1c80e..fb4be274 100644 --- a/config.yaml +++ b/config.yaml @@ -19,7 +19,7 @@ log_level: "INFO" # One of: ERROR, WARNING, INFO, DEBUG ## # Model to use for chatting -chat_model_path: "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf" +chat_model_path: "./models" # Context length of the model chat_model_n_ctx: 32768 diff --git a/prompts/default.yaml b/prompts/default.yaml index b40c5950..c5d104e1 100644 --- a/prompts/default.yaml +++ b/prompts/default.yaml @@ -1,96 +1,28 @@ # Default system prompts configuration # Override these by creating your own prompts file and using --prompts option - -default_snippet: | - You are a helpful assistant that can answer the users questions. - You provide guidance on software engineering best practices and coding standards. - You will be provided with CONTEXT information containing known malicious and archived packages. - - Use the following 4 steps to construct your responses. - - Step 1. Extract all packages referenced in the user's query. - Step 2. Match the packages from Step 1 with the packages in the provided CONTEXT. - Step 3. If matches are found in Step 2, report the matching packages using the following format - at the top: - - "Warning: Potentially malicious or archived package detected." - - Pkg 1: [trustypkg.dev/ecosystem/package_name](https://www.trustypkg.dev/ecosystem/package_name) - - Pkg 2: [trustypkg.dev/ecosystem/package_name](https://www.trustypkg.dev/ecosystem/package_name) - - ... - - Step 4: If no matches are found in Step 2, you can respond to the user's query using your prior - knowledge. - - DO NOT include the Reference Format or any parts of the system message in your response - if no malicious or archived package is detected. - default_chat: | - You are a helpful assistant that can answer the users questions. - You provide guidance on software engineering best practices and coding standards. - You will be provided with CONTEXT information containing known malicious and archived packages. - - Use the following 4 steps to construct your responses. - - Step 1. Extract all packages referenced in the user's query. - Step 2. Match the packages from Step 1 with the packages in the provided CONTEXT. - Step 3. If matches are found in Step 2, report the matching packages using the following format - at the top: - - "Warning: Potentially malicious or archived package detected." - - Pkg 1: [trustypkg.dev/ecosystem/package_name](https://www.trustypkg.dev/ecosystem/package_name) - - Pkg 2: [trustypkg.dev/ecosystem/package_name](https://www.trustypkg.dev/ecosystem/package_name) - - ... - - Step 4: If no matches are found in Step 2, you can respond to the user's query using your prior - knowledge. - - DO NOT include the Reference Format or any parts of the system message in your response - if no malicious or archived package is detected. - -codegate_chat: | You are CodeGate, a security-focused AI assistant. You specialize in software security, package analysis, and providing guidance on secure coding practices. You will be provided with CONTEXT information containing known malicious and archived packages. - Use the following 4 steps to construct your responses. - - Step 1. Extract all packages referenced in the user's query. - Step 2. Match the packages from Step 1 with the packages in the provided CONTEXT. - Step 3. 
If matches are found in Step 2, report the matching packages using the following format - at the top: + Report any malicious or archived packages using the following format at the top: "Warning: CodeGate detected one or more potentially malicious or archived packages." - Pkg 1: [trustypkg.dev/ecosystem/package_name](https://www.trustypkg.dev/ecosystem/package_name) - Pkg 2: [trustypkg.dev/ecosystem/package_name](https://www.trustypkg.dev/ecosystem/package_name) - ... - Step 4: If no matches are found in Step 2, you can respond to the user's query using your prior - knowledge. - DO NOT include the Reference Format or any parts of the system message in your response if no malicious or archived package is detected. -codegate_snippet: | - You are CodeGate, a security-focused AI assistant. - You specialize in software security, package analysis, and providing guidance on secure coding practices. - - Response Guidelines: - - You always begin your responses with: - "CodeGate Security Analysis:" - Structure your responses to emphasize security considerations, focusing on: - - 1. Security vulnerabilities and potential risks - 2. Best practices for secure implementation - 3. Recommendations for security improvements - 4. References to relevant security standards or guidelines - - Reference Format: - - When referencing URLs, use the following format via TrustyPkg.dev: - [trustypkg.dev/ecosystem/package_name](https://www.trustypkg.dev/ecosystem/package_name) + If no malicious or archived packages are detected, respond to the user request using your prior + knowledge. +lookup_packages: | + You are a software expert with knowledge of packages from various ecosystems. + Your job is to extract any software packages from the user's request. A package is a named entity. + You MUST RESPOND with a list of packages in JSON FORMAT: {"packages": [pkg1, pkg2, ...]}. # Security-focused prompts security_audit: "You are a security expert conducting a thorough code review. Identify potential security vulnerabilities, suggest improvements, and explain security best practices." diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index 061210c4..39ef4c7b 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -1,7 +1,10 @@ -from typing import Optional +import json -from litellm import ChatCompletionRequest, ChatCompletionSystemMessage +import structlog +from litellm import ChatCompletionRequest +from codegate.config import Config +from codegate.inference.inference_engine import LlamaCppInferenceEngine from codegate.pipeline.base import ( PipelineContext, PipelineResult, @@ -10,6 +13,8 @@ from src.codegate.storage.storage_engine import StorageEngine from src.codegate.utils.utils import generate_vector_string +logger = structlog.get_logger("codegate") + class CodegateContextRetriever(PipelineStep): """ @@ -17,11 +22,9 @@ class CodegateContextRetriever(PipelineStep): the word "codegate" in the user message. 
""" - def __init__(self, system_prompt_message: Optional[str] = None): - self._system_message = ChatCompletionSystemMessage( - content=system_prompt_message, role="system" - ) + def __init__(self): self.storage_engine = StorageEngine() + self.inference_engine = LlamaCppInferenceEngine() @property def name(self) -> str: @@ -30,8 +33,10 @@ def name(self) -> str: """ return "codegate-context-retriever" - async def get_objects_from_search(self, search: str) -> list[object]: - objects = await self.storage_engine.search(search, distance=0.5) + async def get_objects_from_search( + self, search: str, packages: list[str] = None + ) -> list[object]: + objects = await self.storage_engine.search(search, distance=0.8, packages=packages) return objects def generate_context_str(self, objects: list[object]) -> str: @@ -48,49 +53,87 @@ def generate_context_str(self, objects: list[object]) -> str: context_str += package_str + "\n" return context_str + async def __lookup_packages(self, user_query: str): + ## Check which packages are referenced in the user query + request = { + "messages": [ + {"role": "system", "content": Config.get_config().prompts.lookup_packages}, + {"role": "user", "content": user_query}, + ], + "model": "qwen2-1_5b-instruct-q5_k_m", + "stream": False, + "response_format": {"type": "json_object"}, + "temperature": 0, + } + + result = await self.inference_engine.chat( + f"{Config.get_config().model_base_path}/{request['model']}.gguf", + n_ctx=Config.get_config().chat_model_n_ctx, + n_gpu_layers=Config.get_config().chat_model_n_gpu_layers, + **request, + ) + + result = json.loads(result["choices"][0]["message"]["content"]) + logger.info(f"Packages in user query: {result['packages']}") + return result["packages"] + async def process( self, request: ChatCompletionRequest, context: PipelineContext ) -> PipelineResult: """ - Process the completion request and add a system prompt if the user message contains - the word "codegate". 
+ Use RAG DB to add context to the user request """ - # no prompt configured - if not self._system_message["content"]: - return PipelineResult(request=request) + # Get the last user message last_user_message = self.get_last_user_message(request) - if last_user_message is not None: - last_user_message_str, last_user_idx = last_user_message - if last_user_message_str.lower(): - # Look for matches in vector DB - searched_objects = await self.get_objects_from_search(last_user_message_str) - - # If matches are found, add the matched content to context - if len(searched_objects) > 0: - context_str = self.generate_context_str(searched_objects) - - # Make a copy of the request - new_request = request.copy() - - # Add the context to the last user message - # Format: "Context: {context_str} \n Query: {last user message conent}" - # Handle the two cases: (a) message content is str, (b)message content - # is list - message = new_request["messages"][last_user_idx] - if isinstance(message["content"], str): - message["content"] = ( - f'Context: {context_str} \n\n Query: {message["content"]}' - ) - elif isinstance(message["content"], (list, tuple)): - for item in message["content"]: - if isinstance(item, dict) and item.get("type") == "text": - item["text"] = f'Context: {context_str} \n\n Query: {item["text"]}' - - return PipelineResult( - request=new_request, - ) + # Nothing to do if the last user message is None + if last_user_message is None: + return PipelineResult(request=request) + + # Extract packages from the user message + last_user_message_str, last_user_idx = last_user_message + packages = await self.__lookup_packages(last_user_message_str) + + # If the user message does not reference any packages, then just return + if len(packages) == 0: + return PipelineResult(request=request) + + # Look for matches in vector DB using list of packages as filter + searched_objects = await self.get_objects_from_search(last_user_message_str, packages) + + # If matches are found, add the matched content to context + if len(searched_objects) > 0: + # Remove searched objects that are not in packages. This is needed + # since Weaviate performs substring match in the filter. 
+ updated_searched_objects = [] + for searched_object in searched_objects: + if searched_object.properties["name"] in packages: + updated_searched_objects.append(searched_object) + searched_objects = updated_searched_objects + + # Generate context string using the searched objects + logger.info(f"Adding {len(searched_objects)} packages to the context") + context_str = self.generate_context_str(searched_objects) + + # Make a copy of the request + new_request = request.copy() + + # Add the context to the last user message + # Format: "Context: {context_str} \n Query: {last user message content}" + # Handle the two cases: (a) message content is str, (b) message content + # is list + message = new_request["messages"][last_user_idx] + if isinstance(message["content"], str): + message["content"] = f'Context: {context_str} \n\n Query: {message["content"]}' + elif isinstance(message["content"], (list, tuple)): + for item in message["content"]: + if isinstance(item, dict) and item.get("type") == "text": + item["text"] = f'Context: {context_str} \n\n Query: {item["text"]}' + + return PipelineResult( + request=new_request, + ) # Fall through return PipelineResult(request=request) diff --git a/src/codegate/pipeline/codegate_system_prompt/__init__.py b/src/codegate/pipeline/codegate_system_prompt/__init__.py deleted file mode 100644 index 221f2a3b..00000000 --- a/src/codegate/pipeline/codegate_system_prompt/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt - -__all__ = ["CodegateSystemPrompt"] diff --git a/src/codegate/pipeline/codegate_system_prompt/codegate.py b/src/codegate/pipeline/codegate_system_prompt/codegate.py deleted file mode 100644 index 9659bb93..00000000 --- a/src/codegate/pipeline/codegate_system_prompt/codegate.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Optional - -from litellm import ChatCompletionRequest, ChatCompletionSystemMessage - -from codegate.pipeline.base import ( - PipelineContext, - PipelineResult, - PipelineStep, -) - - -class CodegateSystemPrompt(PipelineStep): - """ - Pipeline step that adds a system prompt to the completion request when it detects - the word "codegate" in the user message. - """ - - def __init__(self, system_prompt_message: Optional[str] = None): - self._system_message = ChatCompletionSystemMessage( - content=system_prompt_message, role="system" - ) - - @property - def name(self) -> str: - """ - Returns the name of this pipeline step. - """ - return "codegate-system-prompt" - - async def process( - self, request: ChatCompletionRequest, context: PipelineContext - ) -> PipelineResult: - """ - Process the completion request and add a system prompt if the user message contains - the word "codegate". 
- """ - # no prompt configured - if not self._system_message["content"]: - return PipelineResult(request=request) - - last_user_message = self.get_last_user_message(request) - - if last_user_message is not None: - last_user_message_str, last_user_idx = last_user_message - if "codegate" in last_user_message_str.lower(): - # Add a system prompt to the completion request - new_request = request.copy() - new_request["messages"].insert(last_user_idx, self._system_message) - return PipelineResult( - request=new_request, - ) - - # Fall through - return PipelineResult(request=request) diff --git a/src/codegate/pipeline/system_prompt/__init__.py b/src/codegate/pipeline/system_prompt/__init__.py new file mode 100644 index 00000000..cbba646e --- /dev/null +++ b/src/codegate/pipeline/system_prompt/__init__.py @@ -0,0 +1,3 @@ +from codegate.pipeline.system_prompt.codegate import SystemPrompt + +__all__ = ["SystemPrompt"] diff --git a/src/codegate/pipeline/system_prompt/codegate.py b/src/codegate/pipeline/system_prompt/codegate.py new file mode 100644 index 00000000..475e5682 --- /dev/null +++ b/src/codegate/pipeline/system_prompt/codegate.py @@ -0,0 +1,57 @@ +from typing import Optional + +from litellm import ChatCompletionRequest, ChatCompletionSystemMessage + +from codegate.pipeline.base import ( + PipelineContext, + PipelineResult, + PipelineStep, +) + + +class SystemPrompt(PipelineStep): + """ + Pipeline step that adds a system prompt to the completion request when it detects + the word "codegate" in the user message. + """ + + def __init__(self, system_prompt: str): + self._system_message = ChatCompletionSystemMessage( + content=system_prompt, role="system" + ) + + @property + def name(self) -> str: + """ + Returns the name of this pipeline step. + """ + return "system-prompt" + + async def process( + self, request: ChatCompletionRequest, context: PipelineContext + ) -> PipelineResult: + """ + Add system prompt if not present, otherwise prepend codegate system prompt + to the existing system prompt + """ + new_request = request.copy() + + if "messages" not in new_request: + new_request["messages"] = [] + + request_system_message = None + for message in new_request["messages"]: + if message["role"] == "system": + request_system_message = message + + if request_system_message is None: + # Add system message + new_request["messages"].insert(0, self._system_message) + elif "codegate" not in request_system_message["content"].lower(): + # Prepend to the system message + request_system_message["content"] = self._system_message["content"] + \ + "\n Here are additional instructions. 
\n " + request_system_message["content"] + + return PipelineResult( + request=new_request, + ) diff --git a/src/codegate/server.py b/src/codegate/server.py index efae9cb8..a86c06e0 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -7,7 +7,7 @@ from codegate.dashboard.dashboard import dashboard_router from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever -from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt +from codegate.pipeline.system_prompt.codegate import SystemPrompt from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippetExtractor from codegate.pipeline.extract_snippets.output import CodeCommentStep from codegate.pipeline.output import OutputPipelineProcessor, OutputPipelineStep @@ -40,8 +40,8 @@ def init_app() -> FastAPI: steps: List[PipelineStep] = [ CodegateVersion(), CodeSnippetExtractor(), - #CodegateSystemPrompt(Config.get_config().prompts.codegate_chat), - CodegateContextRetriever(Config.get_config().prompts.codegate_chat), + SystemPrompt(Config.get_config().prompts.default_chat), + CodegateContextRetriever(), CodegateSecrets(), ] # Leaving the pipeline empty for now diff --git a/src/codegate/storage/storage_engine.py b/src/codegate/storage/storage_engine.py index 8d2683bc..500d74a1 100644 --- a/src/codegate/storage/storage_engine.py +++ b/src/codegate/storage/storage_engine.py @@ -3,6 +3,7 @@ from weaviate.classes.config import DataType from weaviate.classes.query import MetadataQuery from weaviate.embedded import EmbeddedOptions +import weaviate.classes as wvc from codegate.config import Config from codegate.inference.inference_engine import LlamaCppInferenceEngine @@ -87,7 +88,7 @@ def setup_schema(self, client): ) logger.info(f"Weaviate schema for class {class_config['name']} setup complete.") - async def search(self, query: str, limit=5, distance=0.3) -> list[object]: + async def search(self, query: str, limit=5, distance=0.3, packages=None) -> list[object]: """ Search the 'Package' collection based on a query string. 
@@ -110,12 +111,21 @@ async def search(self, query: str, limit=5, distance=0.3) -> list[object]: try: weaviate_client.connect() collection = weaviate_client.collections.get("Package") - response = collection.query.near_vector( - query_vector[0], - limit=limit, - distance=distance, - return_metadata=MetadataQuery(distance=True), - ) + if packages: + response = collection.query.near_vector( + query_vector[0], + limit=limit, + distance=distance, + filters=wvc.query.Filter.by_property("name").contains_any(packages), + return_metadata=MetadataQuery(distance=True), + ) + else: + response = collection.query.near_vector( + query_vector[0], + limit=limit, + distance=distance, + return_metadata=MetadataQuery(distance=True), + ) weaviate_client.close() if not response: diff --git a/tests/pipeline/codegate_system_prompt/test_codegate_system_prompt.py b/tests/pipeline/codegate_system_prompt/test_codegate_system_prompt.py deleted file mode 100644 index f7f8e8be..00000000 --- a/tests/pipeline/codegate_system_prompt/test_codegate_system_prompt.py +++ /dev/null @@ -1,106 +0,0 @@ -from unittest.mock import Mock - -import pytest -from litellm.types.llms.openai import ChatCompletionRequest - -from codegate.pipeline.base import PipelineContext -from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt - - -@pytest.mark.asyncio -class TestCodegateSystemPrompt: - def test_init_no_system_message(self): - """ - Test initialization with no system message - """ - step = CodegateSystemPrompt() - assert step._system_message["content"] is None - - def test_init_with_system_message(self): - """ - Test initialization with a system message - """ - test_message = "Test system prompt" - step = CodegateSystemPrompt(system_prompt_message=test_message) - assert step._system_message["content"] == test_message - - @pytest.mark.parametrize( - "user_message,expected_modification", - [ - # Test cases with different scenarios - ("Hello CodeGate", True), - ("CODEGATE in uppercase", True), - ("No matching message", False), - ("codegate with lowercase", True), - ], - ) - async def test_process_system_prompt_insertion(self, user_message, expected_modification): - """ - Test system prompt insertion based on message content - """ - # Prepare mock request with user message - mock_request = {"messages": [{"role": "user", "content": user_message}]} - mock_context = Mock(spec=PipelineContext) - - # Create system prompt step - system_prompt = "Security analysis system prompt" - step = CodegateSystemPrompt(system_prompt_message=system_prompt) - - # Mock the get_last_user_message method - step.get_last_user_message = Mock(return_value=(user_message, 0)) - - # Process the request - result = await step.process(ChatCompletionRequest(**mock_request), mock_context) - - if expected_modification: - # Check that system message was inserted - assert len(result.request["messages"]) == 2 - assert result.request["messages"][0]["role"] == "system" - assert result.request["messages"][0]["content"] == system_prompt - assert result.request["messages"][1]["role"] == "user" - assert result.request["messages"][1]["content"] == user_message - else: - # Ensure no modification occurred - assert len(result.request["messages"]) == 1 - - async def test_no_system_message_configured(self): - """ - Test behavior when no system message is configured - """ - mock_request = {"messages": [{"role": "user", "content": "CodeGate test"}]} - mock_context = Mock(spec=PipelineContext) - - # Create step without system message - step = CodegateSystemPrompt() - 
- # Process the request - result = await step.process(ChatCompletionRequest(**mock_request), mock_context) - - # Verify request remains unchanged - assert result.request == mock_request - - @pytest.mark.parametrize( - "edge_case", - [ - None, # No messages - [], # Empty messages list - ], - ) - async def test_edge_cases(self, edge_case): - """ - Test edge cases with None or empty message list - """ - mock_request = {"messages": edge_case} if edge_case is not None else {} - mock_context = Mock(spec=PipelineContext) - - system_prompt = "Security edge case prompt" - step = CodegateSystemPrompt(system_prompt_message=system_prompt) - - # Mock get_last_user_message to return None - step.get_last_user_message = Mock(return_value=None) - - # Process the request - result = await step.process(ChatCompletionRequest(**mock_request), mock_context) - - # Verify request remains unchanged - assert result.request == mock_request diff --git a/tests/pipeline/system_prompt/test_system_prompt.py b/tests/pipeline/system_prompt/test_system_prompt.py new file mode 100644 index 00000000..06f92733 --- /dev/null +++ b/tests/pipeline/system_prompt/test_system_prompt.py @@ -0,0 +1,109 @@ +from unittest.mock import Mock + +import pytest +from litellm.types.llms.openai import ChatCompletionRequest + +from codegate.pipeline.base import PipelineContext +from codegate.pipeline.system_prompt.codegate import SystemPrompt + + +class TestSystemPrompt: + def test_init_with_system_message(self): + """ + Test initialization with a system message + """ + test_message = "Test system prompt" + step = SystemPrompt(system_prompt=test_message) + assert step._system_message["content"] == test_message + + @pytest.mark.asyncio + async def test_process_system_prompt_insertion(self): + """ + Test system prompt insertion based on message content + """ + # Prepare mock request with user message + user_message = "Test user message" + mock_request = {"messages": [{"role": "user", "content": user_message}]} + mock_context = Mock(spec=PipelineContext) + + # Create system prompt step + system_prompt = "Security analysis system prompt" + step = SystemPrompt(system_prompt=system_prompt) + + # Mock the get_last_user_message method + step.get_last_user_message = Mock(return_value=(user_message, 0)) + + # Process the request + result = await step.process(ChatCompletionRequest(**mock_request), mock_context) + + # Check that system message was inserted + assert len(result.request["messages"]) == 2 + assert result.request["messages"][0]["role"] == "system" + assert result.request["messages"][0]["content"] == system_prompt + assert result.request["messages"][1]["role"] == "user" + assert result.request["messages"][1]["content"] == user_message + + @pytest.mark.asyncio + async def test_process_system_prompt_update(self): + """ + Test system prompt update + """ + # Prepare mock request with user message + request_system_message = "Existing system message" + user_message = "Test user message" + mock_request = { + "messages": [ + {"role": "system", "content": request_system_message}, + {"role": "user", "content": user_message}, + ] + } + mock_context = Mock(spec=PipelineContext) + + # Create system prompt step + system_prompt = "Security analysis system prompt" + step = SystemPrompt(system_prompt=system_prompt) + + # Mock the get_last_user_message method + step.get_last_user_message = Mock(return_value=(user_message, 0)) + + # Process the request + result = await step.process(ChatCompletionRequest(**mock_request), mock_context) + + # Check that system message 
was updated + assert len(result.request["messages"]) == 2 + assert result.request["messages"][0]["role"] == "system" + assert ( + result.request["messages"][0]["content"] + == system_prompt + "\n Here are additional instructions. \n " + request_system_message + ) + assert result.request["messages"][1]["role"] == "user" + assert result.request["messages"][1]["content"] == user_message + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "edge_case", + [ + None, # No messages + [], # Empty messages list + ], + ) + async def test_edge_cases(self, edge_case): + """ + Test edge cases with None or empty message list + """ + mock_request = {"messages": edge_case} if edge_case is not None else {} + mock_context = Mock(spec=PipelineContext) + + system_prompt = "Security edge case prompt" + step = SystemPrompt(system_prompt=system_prompt) + + # Mock get_last_user_message to return None + step.get_last_user_message = Mock(return_value=None) + + # Process the request + result = await step.process(ChatCompletionRequest(**mock_request), mock_context) + + # Verify the system prompt was still added + assert len(result.request["messages"]) == 1 + assert result.request["messages"][0]["role"] == "system" + assert result.request["messages"][0]["content"] == system_prompt diff --git a/tests/test_cli.py b/tests/test_cli.py index e5ed7e98..d72e562b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -74,7 +74,7 @@ def test_serve_default_options( "port": 8989, "log_level": "INFO", "log_format": "JSON", - "prompts_loaded": 7, # Default prompts are loaded + "prompts_loaded": 5, # Default prompts are loaded "provider_urls": DEFAULT_PROVIDER_URLS, }, ) @@ -113,7 +113,7 @@ def test_serve_custom_options( "port": 8989, "log_level": "DEBUG", "log_format": "TEXT", - "prompts_loaded": 7, # Default prompts are loaded + "prompts_loaded": 5, # Default prompts are loaded "provider_urls": DEFAULT_PROVIDER_URLS, }, ) @@ -153,7 +153,7 @@ def test_serve_with_config_file( "port": 8989, "log_level": "DEBUG", "log_format": "JSON", - "prompts_loaded": 7, # Default prompts are loaded + "prompts_loaded": 5, # Default prompts are loaded "provider_urls": DEFAULT_PROVIDER_URLS, }, ) @@ -205,7 +205,7 @@ def test_serve_priority_resolution( "port": 8080, "log_level": "ERROR", "log_format": "TEXT", - "prompts_loaded": 7, # Default prompts are loaded + "prompts_loaded": 5, # Default prompts are loaded "provider_urls": DEFAULT_PROVIDER_URLS, }, ) diff --git a/tests/test_cli_prompts.py b/tests/test_cli_prompts.py index 2b5029a8..7f73567c 100644 --- a/tests/test_cli_prompts.py +++ b/tests/test_cli_prompts.py @@ -39,9 +39,6 @@ def test_show_default_prompts(): assert result.exit_code == 0 assert "Loaded prompts:" in result.output assert "default_chat:" in result.output - assert "default_snippet:" in result.output - assert "codegate_chat:" in result.output - assert "codegate_snippet:" in result.output assert "security_audit:" in result.output assert "red_team:" in result.output assert "blue_team:" in result.output diff --git a/tests/test_prompts.py b/tests/test_prompts.py index e28863f0..9adf4152 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -68,7 +68,7 @@ def test_default_prompts(): config = Config.load() assert len(config.prompts.prompts) > 0 assert hasattr(config.prompts, "default_chat") - assert "You are a helpful assistant" in config.prompts.default_chat + assert "You are CodeGate" in config.prompts.default_chat def test_cli_prompts_override_default(temp_prompts_file):
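
Reviewer note (illustrative only, not part of the patch): the sketch below restates how the updated CodegateContextRetriever rewrites the last user message once matching packages are found, using plain dicts in place of litellm's ChatCompletionRequest. The helper name and the sample context string are hypothetical; the message format mirrors the added code above.

import copy


def prepend_context_to_user_message(request: dict, last_user_idx: int, context_str: str) -> dict:
    # Mirror of the added logic: wrap the last user message as
    # "Context: {context_str} \n\n Query: {original content}", handling both
    # plain-string content and list-of-parts content.
    new_request = copy.deepcopy(request)
    message = new_request["messages"][last_user_idx]
    if isinstance(message["content"], str):
        message["content"] = f'Context: {context_str} \n\n Query: {message["content"]}'
    elif isinstance(message["content"], (list, tuple)):
        for item in message["content"]:
            if isinstance(item, dict) and item.get("type") == "text":
                item["text"] = f'Context: {context_str} \n\n Query: {item["text"]}'
    return new_request


if __name__ == "__main__":
    # Hypothetical example: the user asks about a package that the vector DB flagged.
    request = {"messages": [{"role": "user", "content": "Should I depend on invokehttp?"}]}
    context = "Package: invokehttp, ecosystem: pypi, status: malicious"
    print(prepend_context_to_user_message(request, last_user_idx=0, context_str=context))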