Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Use def in FastAPI dashboard calls. #222

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 7 additions & 20 deletions src/codegate/dashboard/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from fastapi import APIRouter

from codegate.dashboard.post_processing import (
match_conversations,
parse_get_alert_conversation,
parse_get_prompt_with_output,
parse_messages_in_conversations,
)
from codegate.dashboard.request_models import AlertConversation, Conversation
from codegate.db.connection import DbReader
Expand All @@ -19,31 +18,19 @@


@dashboard_router.get("/dashboard/messages")
async def get_messages() -> List[Conversation]:
def get_messages() -> List[Conversation]:
"""
Get all the messages from the database and return them as a list of conversations.
"""
prompts_outputs = await db_reader.get_prompts_with_output()
prompts_outputs = asyncio.run(db_reader.get_prompts_with_output())

# Parse the prompts and outputs in parallel
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs]
partial_conversations = [task.result() for task in tasks]

conversations = await match_conversations(partial_conversations)
return conversations
return asyncio.run(parse_messages_in_conversations(prompts_outputs))


@dashboard_router.get("/dashboard/alerts")
async def get_alerts() -> List[AlertConversation]:
def get_alerts() -> List[AlertConversation]:
"""
Get all the messages from the database and return them as a list of conversations.
"""
alerts_prompt_output = await db_reader.get_alerts_with_prompt_and_output()

# Parse the prompts and outputs in parallel
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(parse_get_alert_conversation(row)) for row in alerts_prompt_output]
alert_conversations = [task.result() for task in tasks if task.result() is not None]

return alert_conversations
alerts_prompt_output = asyncio.run(db_reader.get_alerts_with_prompt_and_output())
return asyncio.run(parse_get_alert_conversation(alerts_prompt_output))
41 changes: 39 additions & 2 deletions src/codegate/dashboard/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,23 @@ async def match_conversations(
return conversations


async def parse_get_alert_conversation(
async def parse_messages_in_conversations(
prompts_outputs: List[GetPromptWithOutputsRow],
) -> List[Conversation]:
"""
Get all the messages from the database and return them as a list of conversations.
"""

# Parse the prompts and outputs in parallel
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs]
partial_conversations = [task.result() for task in tasks]

conversations = await match_conversations(partial_conversations)
return conversations


async def parse_row_alert_conversation(
row: GetAlertsWithPromptAndOutputRow,
) -> Optional[AlertConversation]:
"""
Expand All @@ -220,12 +236,33 @@ async def parse_get_alert_conversation(
conversation_timestamp=row.timestamp,
)
code_snippet = json.loads(row.code_snippet) if row.code_snippet else None
trigger_string = None
if row.trigger_string:
try:
trigger_string = json.loads(row.trigger_string)
except Exception:
trigger_string = row.trigger_string

return AlertConversation(
conversation=conversation,
alert_id=row.id,
code_snippet=code_snippet,
trigger_string=row.trigger_string,
trigger_string=trigger_string,
trigger_type=row.trigger_type,
trigger_category=row.trigger_category,
timestamp=row.timestamp,
)


async def parse_get_alert_conversation(
alerts_conversations: List[GetAlertsWithPromptAndOutputRow],
) -> List[AlertConversation]:
"""
Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of
AlertConversation

The rows contain the raw request and output strings from the pipeline.
"""
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(parse_row_alert_conversation(row)) for row in alerts_conversations]
return [task.result() for task in tasks if task.result() is not None]
4 changes: 2 additions & 2 deletions src/codegate/dashboard/request_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import List, Optional
from typing import List, Optional, Union

from pydantic import BaseModel

Expand Down Expand Up @@ -57,7 +57,7 @@ class AlertConversation(BaseModel):
conversation: Conversation
alert_id: str
code_snippet: Optional[CodeSnippet]
trigger_string: Optional[str]
trigger_string: Optional[Union[str, dict]]
trigger_type: str
trigger_category: Optional[str]
timestamp: datetime.datetime
4 changes: 4 additions & 0 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ async def process(
# Look for matches in vector DB using list of packages as filter
searched_objects = await self.get_objects_from_search(last_user_message_str, packages)

logger.info(
f"Found {len(searched_objects)} matches in the database",
searched_objects=searched_objects,
)
# If matches are found, add the matched content to context
if len(searched_objects) > 0:
# Remove searched objects that are not in packages. This is needed
Expand Down
4 changes: 1 addition & 3 deletions src/codegate/pipeline/system_prompt/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,4 @@ async def process(
context.add_alert(self.name, trigger_string=prepended_message)
request_system_message["content"] = prepended_message

return PipelineResult(
request=new_request,
)
return PipelineResult(request=new_request, context=context)
Loading