diff --git a/src/codegate/dashboard/dashboard.py b/src/codegate/dashboard/dashboard.py index 0cd6a136..2814c494 100644 --- a/src/codegate/dashboard/dashboard.py +++ b/src/codegate/dashboard/dashboard.py @@ -5,9 +5,8 @@ from fastapi import APIRouter from codegate.dashboard.post_processing import ( - match_conversations, parse_get_alert_conversation, - parse_get_prompt_with_output, + parse_messages_in_conversations, ) from codegate.dashboard.request_models import AlertConversation, Conversation from codegate.db.connection import DbReader @@ -19,31 +18,19 @@ @dashboard_router.get("/dashboard/messages") -async def get_messages() -> List[Conversation]: +def get_messages() -> List[Conversation]: """ Get all the messages from the database and return them as a list of conversations. """ - prompts_outputs = await db_reader.get_prompts_with_output() + prompts_outputs = asyncio.run(db_reader.get_prompts_with_output()) - # Parse the prompts and outputs in parallel - async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs] - partial_conversations = [task.result() for task in tasks] - - conversations = await match_conversations(partial_conversations) - return conversations + return asyncio.run(parse_messages_in_conversations(prompts_outputs)) @dashboard_router.get("/dashboard/alerts") -async def get_alerts() -> List[AlertConversation]: +def get_alerts() -> List[AlertConversation]: """ Get all the messages from the database and return them as a list of conversations. """ - alerts_prompt_output = await db_reader.get_alerts_with_prompt_and_output() - - # Parse the prompts and outputs in parallel - async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(parse_get_alert_conversation(row)) for row in alerts_prompt_output] - alert_conversations = [task.result() for task in tasks if task.result() is not None] - - return alert_conversations + alerts_prompt_output = asyncio.run(db_reader.get_alerts_with_prompt_and_output()) + return asyncio.run(parse_get_alert_conversation(alerts_prompt_output)) diff --git a/src/codegate/dashboard/post_processing.py b/src/codegate/dashboard/post_processing.py index 13c9217e..c2e1059b 100644 --- a/src/codegate/dashboard/post_processing.py +++ b/src/codegate/dashboard/post_processing.py @@ -200,7 +200,23 @@ async def match_conversations( return conversations -async def parse_get_alert_conversation( +async def parse_messages_in_conversations( + prompts_outputs: List[GetPromptWithOutputsRow], +) -> List[Conversation]: + """ + Get all the messages from the database and return them as a list of conversations. + """ + + # Parse the prompts and outputs in parallel + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs] + partial_conversations = [task.result() for task in tasks] + + conversations = await match_conversations(partial_conversations) + return conversations + + +async def parse_row_alert_conversation( row: GetAlertsWithPromptAndOutputRow, ) -> Optional[AlertConversation]: """ @@ -220,12 +236,33 @@ async def parse_get_alert_conversation( conversation_timestamp=row.timestamp, ) code_snippet = json.loads(row.code_snippet) if row.code_snippet else None + trigger_string = None + if row.trigger_string: + try: + trigger_string = json.loads(row.trigger_string) + except Exception: + trigger_string = row.trigger_string + return AlertConversation( conversation=conversation, alert_id=row.id, code_snippet=code_snippet, - trigger_string=row.trigger_string, + trigger_string=trigger_string, trigger_type=row.trigger_type, trigger_category=row.trigger_category, timestamp=row.timestamp, ) + + +async def parse_get_alert_conversation( + alerts_conversations: List[GetAlertsWithPromptAndOutputRow], +) -> List[AlertConversation]: + """ + Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of + AlertConversation + + The rows contain the raw request and output strings from the pipeline. + """ + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(parse_row_alert_conversation(row)) for row in alerts_conversations] + return [task.result() for task in tasks if task.result() is not None] diff --git a/src/codegate/dashboard/request_models.py b/src/codegate/dashboard/request_models.py index 7750a2f4..d33e8732 100644 --- a/src/codegate/dashboard/request_models.py +++ b/src/codegate/dashboard/request_models.py @@ -1,5 +1,5 @@ import datetime -from typing import List, Optional +from typing import List, Optional, Union from pydantic import BaseModel @@ -57,7 +57,7 @@ class AlertConversation(BaseModel): conversation: Conversation alert_id: str code_snippet: Optional[CodeSnippet] - trigger_string: Optional[str] + trigger_string: Optional[Union[str, dict]] trigger_type: str trigger_category: Optional[str] timestamp: datetime.datetime diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index 2fab586c..1c5a5617 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -88,6 +88,10 @@ async def process( # Look for matches in vector DB using list of packages as filter searched_objects = await self.get_objects_from_search(last_user_message_str, packages) + logger.info( + f"Found {len(searched_objects)} matches in the database", + searched_objects=searched_objects, + ) # If matches are found, add the matched content to context if len(searched_objects) > 0: # Remove searched objects that are not in packages. This is needed diff --git a/src/codegate/pipeline/system_prompt/codegate.py b/src/codegate/pipeline/system_prompt/codegate.py index 40e8c832..8a08ae2a 100644 --- a/src/codegate/pipeline/system_prompt/codegate.py +++ b/src/codegate/pipeline/system_prompt/codegate.py @@ -56,6 +56,4 @@ async def process( context.add_alert(self.name, trigger_string=prepended_message) request_system_message["content"] = prepended_message - return PipelineResult( - request=new_request, - ) + return PipelineResult(request=new_request, context=context)