Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Add copilot headers/auth for extracting package/ecosystem #314

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/codegate/llm_utils/extractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional

import structlog

Expand All @@ -24,6 +24,7 @@ async def extract_packages(
model: str = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None
) -> List[str]:
"""Extract package names from the given content."""
system_prompt = Config.get_config().prompts.lookup_packages
Expand All @@ -35,6 +36,7 @@ async def extract_packages(
model=model,
api_key=api_key,
base_url=base_url,
extra_headers=extra_headers,
)

# Handle both formats: {"packages": [...]} and direct list [...]
Expand All @@ -49,6 +51,7 @@ async def extract_ecosystem(
model: str = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None
) -> List[str]:
"""Extract ecosystem from the given content."""
system_prompt = Config.get_config().prompts.lookup_ecosystem
Expand All @@ -60,6 +63,7 @@ async def extract_ecosystem(
model=model,
api_key=api_key,
base_url=base_url,
extra_headers=extra_headers,
)

ecosystem = result if isinstance(result, str) else result.get("ecosystem")
Expand Down
4 changes: 4 additions & 0 deletions src/codegate/llm_utils/llmclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ async def complete(
model: str = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
**kwargs,
) -> Dict[str, Any]:
"""
Expand All @@ -53,6 +54,7 @@ async def complete(
model,
api_key,
base_url,
extra_headers,
**kwargs,
)

Expand Down Expand Up @@ -102,6 +104,7 @@ async def _complete_litellm(
model: str,
api_key: str,
base_url: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
**kwargs,
) -> Dict[str, Any]:
# Use the private method to create the request
Expand Down Expand Up @@ -134,6 +137,7 @@ async def _complete_litellm(
temperature=request["temperature"],
base_url=base_url,
response_format=request["response_format"],
extra_headers=extra_headers
)
content = response["choices"][0]["message"]["content"]

Expand Down
5 changes: 4 additions & 1 deletion src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ async def process_request(
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None
) -> PipelineResult:
"""Process a request through all pipeline steps"""
self.context.sensitive = PipelineSensitiveData(
Expand All @@ -235,6 +236,7 @@ async def process_request(
api_base=api_base,
)
self.context.metadata["prompt_id"] = prompt_id
self.context.metadata["extra_headers"] = extra_headers
current_request = request

for step in self.pipeline_steps:
Expand Down Expand Up @@ -271,9 +273,10 @@ async def process_request(
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None
) -> PipelineResult:
"""Create a new pipeline instance and process the request"""
instance = self.create_instance()
return await instance.process_request(
request, provider, prompt_id, model, api_key, api_base
request, provider, prompt_id, model, api_key, api_base, extra_headers
)
2 changes: 2 additions & 0 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ async def __lookup_packages(self, user_query: str, context: PipelineContext):
model=context.sensitive.model,
api_key=context.sensitive.api_key,
base_url=context.sensitive.api_base,
extra_headers=context.metadata.get('extra_headers', None),
)

logger.info(f"Packages in user query: {packages}")
Expand All @@ -79,6 +80,7 @@ async def __lookup_ecosystem(self, user_query: str, context: PipelineContext):
model=context.sensitive.model,
api_key=context.sensitive.api_key,
base_url=context.sensitive.api_base,
extra_headers=context.metadata.get('extra_headers', None),
)

logger.info(f"Ecosystem in user query: {ecosystem}")
Expand Down
4 changes: 4 additions & 0 deletions src/codegate/pipeline/secrets/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ async def process(
Returns:
PipelineResult containing the processed request and context with redaction metadata
"""

if 'messages' not in request:
return PipelineResult(request=request, context=context)

secrets_manager = context.sensitive.manager
if not secrets_manager or not isinstance(secrets_manager, SecretsManager):
raise ValueError("Secrets manager not found in context")
Expand Down
19 changes: 17 additions & 2 deletions src/codegate/providers/copilot/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from abc import ABC, abstractmethod
from typing import Dict

import structlog
from litellm.types.llms.openai import ChatCompletionRequest
Expand Down Expand Up @@ -41,6 +42,18 @@ def _request_id(headers: list[str]) -> str:
print("No request ID found in headers")
return ""

@staticmethod
def _get_copilot_headers(headers: Dict[str, str]) -> Dict[str, str]:
copilot_header_names = ['copilot-integration-id', 'editor-plugin-version', 'editor-version',
'openai-intent', 'openai-organization', 'user-agent',
'vscode-machineid', 'vscode-sessionid', 'x-github-api-version',
'x-request-id']
copilot_headers = {}
for a_name in copilot_header_names:
copilot_headers[a_name] = headers.get(a_name, '')

return copilot_headers

async def process_body(self, headers: list[str], body: bytes) -> bytes:
"""Common processing logic for all strategies"""
try:
Expand All @@ -51,8 +64,10 @@ async def process_body(self, headers: list[str], body: bytes) -> bytes:
request=normalized_body,
provider=self.provider_name,
prompt_id=self._request_id(headers),
model=normalized_body.get("model", ""),
api_key=None,
model=normalized_body.get("model", "gpt-4o-mini"),
api_key = headers.get('authorization','').replace('Bearer ', ''),
api_base = "https://" + headers.get('host', ''),
extra_headers=CopilotPipeline._get_copilot_headers(headers)
)

if result.request:
Expand Down
Loading