This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Use ollama python client for completion #241

Merged
merged 3 commits on Dec 10, 2024
869 changes: 443 additions & 426 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ readme = "README.md"
authors = []

[tool.poetry.dependencies]
python = ">=3.11"
python = ">=3.11,<4.0"
click = ">=8.1.0"
PyYAML = ">=6.0.1"
fastapi = ">=0.115.5"
@@ -19,6 +19,7 @@ cryptography = "^44.0.0"
sqlalchemy = "^2.0.28"
greenlet = "^3.0.3"
aiosqlite = "^0.20.0"
ollama = ">=0.4.4"

[tool.poetry.group.dev.dependencies]
pytest = ">=7.4.0"
1 change: 1 addition & 0 deletions src/codegate/codegate_logging.py
@@ -75,6 +75,7 @@ def setup_logging(
structlog.processors.CallsiteParameterAdder(
[
structlog.processors.CallsiteParameter.MODULE,
structlog.processors.CallsiteParameter.PATHNAME,
]
),
]
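For reference, a minimal standalone sketch (using plain structlog defaults rather than codegate's full `setup_logging`) of what the added `CallsiteParameter.PATHNAME` contributes: each rendered event now carries the source file path alongside the module name.

```python
import structlog

structlog.configure(
    processors=[
        structlog.processors.CallsiteParameterAdder(
            [
                structlog.processors.CallsiteParameter.MODULE,
                structlog.processors.CallsiteParameter.PATHNAME,  # the field this PR adds
            ]
        ),
        structlog.processors.JSONRenderer(),
    ]
)

log = structlog.get_logger()
log.info("completion finished")  # rendered event now includes "module" and "pathname"
```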
2 changes: 1 addition & 1 deletion src/codegate/config.py
@@ -19,7 +19,7 @@
"openai": "https://api.openai.com/v1",
"anthropic": "https://api.anthropic.com/v1",
"vllm": "http://localhost:8000", # Base URL without /v1 path
"ollama": "http://localhost:11434/api", # Default Ollama server URL
"ollama": "http://localhost:11434", # Default Ollama server URL
}


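As background for the URL change above (a sketch that assumes a local Ollama server reached through the official Python client; the model name is a placeholder, not part of this PR): the `ollama` client takes the bare server URL and appends paths such as `/api/chat` itself, which is why the trailing `/api` is dropped from the default.

```python
import asyncio

from ollama import AsyncClient


async def main() -> None:
    # Host is the bare server URL; the client adds the /api/... paths internally.
    client = AsyncClient(host="http://localhost:11434")

    stream = await client.chat(
        model="llama3",  # placeholder model name
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.message.content or "", end="", flush=True)


asyncio.run(main())
```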
33 changes: 21 additions & 12 deletions src/codegate/llm_utils/llmclient.py
@@ -2,7 +2,7 @@
from typing import Any, Dict, Optional

import structlog
from litellm import acompletion
from litellm import acompletion, completion

from codegate.config import Config
from codegate.inference import LlamaCppInferenceEngine
@@ -112,18 +112,27 @@ async def _complete_litellm(
if not base_url.endswith("/v1"):
base_url = f"{base_url}/v1"
else:
model = f"{provider}/{model}"
if not model.startswith(f"{provider}/"):
model = f"{provider}/{model}"

try:
response = await acompletion(
model=model,
messages=request["messages"],
api_key=api_key,
temperature=request["temperature"],
base_url=base_url,
response_format=request["response_format"],
)

if provider == "ollama":
response = completion(
model=model,
messages=request["messages"],
api_key=api_key,
temperature=request["temperature"],
base_url=base_url,
)
else:
response = await acompletion(
model=model,
messages=request["messages"],
api_key=api_key,
temperature=request["temperature"],
base_url=base_url,
response_format=request["response_format"],
)
content = response["choices"][0]["message"]["content"]

# Clean up code blocks if present
@@ -133,5 +142,5 @@
return json.loads(content)

except Exception as e:
logger.error(f"LiteLLM completion failed: {e}")
logger.error(f"LiteLLM completion failed {provider}/{model} ({content}): {e}")
return {}
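For context on the model-name guard above (an assumption about LiteLLM's routing convention rather than something this diff spells out): LiteLLM dispatches to the Ollama backend based on the `ollama/` prefix, so the code only prepends it when missing. A minimal synchronous sketch mirroring the new code path, with placeholder model and prompt:

```python
from litellm import completion

response = completion(
    model="ollama/llama3",              # "ollama/" prefix selects the Ollama backend
    messages=[{"role": "user", "content": "Say hi"}],
    base_url="http://localhost:11434",  # matches the new default in config.py
    temperature=0,
)
print(response["choices"][0]["message"]["content"])
```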
1 change: 0 additions & 1 deletion src/codegate/pipeline/output.py
@@ -134,7 +134,6 @@ async def process_stream(

# Yield all processed chunks
for c in current_chunks:
logger.debug(f"Yielding chunk {c}")
self._store_chunk_content(c)
self._context.buffer.clear()
yield c
10 changes: 4 additions & 6 deletions src/codegate/providers/base.py
@@ -214,17 +214,18 @@ async def complete(
provider-specific format
"""
normalized_request = self._input_normalizer.normalize(data)
streaming = data.get("stream", False)
streaming = normalized_request.get("stream", False)
prompt_db = await self._db_recorder.record_request(
normalized_request, is_fim_request, self.provider_route_name
)

prompt_db_id = prompt_db.id if prompt_db is not None else None
input_pipeline_result = await self._run_input_pipeline(
normalized_request,
api_key,
data.get("base_url"),
is_fim_request,
prompt_id=prompt_db.id,
prompt_id=prompt_db_id,
)
if input_pipeline_result.response:
await self._db_recorder.record_alerts(input_pipeline_result.context.alerts_raised)
@@ -239,7 +240,6 @@
# Execute the completion and translate the response
# This gives us either a single response or a stream of responses
# based on the streaming flag
logger.info(f"Executing completion with {provider_request}")
model_response = await self._completion_handler.execute_completion(
provider_request, api_key=api_key, stream=streaming, is_fim_request=is_fim_request
)
@@ -259,9 +259,7 @@

model_response = self._db_recorder.record_output_stream(prompt_db, model_response)
pipeline_output_stream = await self._run_output_stream_pipeline(
input_pipeline_result.context,
model_response,
is_fim_request=is_fim_request,
input_pipeline_result.context, model_response, is_fim_request=is_fim_request
)
return self._cleanup_after_streaming(pipeline_output_stream, input_pipeline_result.context)

125 changes: 100 additions & 25 deletions src/codegate/providers/ollama/adapter.py
@@ -1,37 +1,33 @@
from typing import Any, Dict
import uuid
from datetime import datetime, timezone
from typing import Any, AsyncIterator, Dict, Union

from litellm import ChatCompletionRequest
from litellm import ChatCompletionRequest, ModelResponse
from litellm.types.utils import Delta, StreamingChoices
from ollama import ChatResponse, Message

from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer


class OllamaInputNormalizer(ModelInputNormalizer):
def __init__(self):
super().__init__()

def normalize(self, data: Dict) -> ChatCompletionRequest:
"""
Normalize the input data to the format expected by Ollama.
"""
# Make a copy of the data to avoid modifying the original and normalize the message content
normalized_data = self._normalize_content_messages(data)
normalized_data["model"] = data.get("model", "").strip()
normalized_data["options"] = data.get("options", {})

# Add any context or system prompt if provided
if "context" in data:
normalized_data["context"] = data["context"]
if "system" in data:
normalized_data["system"] = data["system"]
if "prompt" in normalized_data:
normalized_data["messages"] = [
{"content": normalized_data.pop("prompt"), "role": "user"}
]

# Format the model name
if "model" in normalized_data:
normalized_data["model"] = data["model"].strip()

# Ensure the base_url ends with /api if provided
if "base_url" in normalized_data:
base_url = normalized_data["base_url"].rstrip("/")
if not base_url.endswith("/api"):
normalized_data["base_url"] = f"{base_url}/api"
# In Ollama force the stream to be True. Continue is not setting this parameter and
# most of our functionality is for streaming completions.
normalized_data["stream"] = True

return ChatCompletionRequest(**normalized_data)

@@ -42,18 +38,98 @@ def denormalize(self, data: ChatCompletionRequest) -> Dict:
return data


class OLlamaToModel(AsyncIterator[ModelResponse]):
def __init__(self, ollama_response: AsyncIterator[ChatResponse]):
self.ollama_response = ollama_response
self._aiter = ollama_response.__aiter__()

def __aiter__(self):
return self

async def __anext__(self):
try:
chunk = await self._aiter.__anext__()
if not isinstance(chunk, ChatResponse):
return chunk

finish_reason = None
role = "assistant"

# Convert the datetime object to a timestamp in seconds
datetime_obj = datetime.fromisoformat(chunk.created_at)
timestamp_seconds = int(datetime_obj.timestamp())

if chunk.done:
finish_reason = "stop"
role = None

model_response = ModelResponse(
id=f"ollama-chat-{str(uuid.uuid4())}",
created=timestamp_seconds,
model=chunk.model,
object="chat.completion.chunk",
choices=[
StreamingChoices(
finish_reason=finish_reason,
index=0,
delta=Delta(content=chunk.message.content, role=role),
logprobs=None,
)
],
)
return model_response
except StopAsyncIteration:
raise StopAsyncIteration


class ModelToOllama(AsyncIterator[ChatResponse]):

def __init__(self, normalized_reply: AsyncIterator[ModelResponse]):
self.normalized_reply = normalized_reply
self._aiter = normalized_reply.__aiter__()

def __aiter__(self):
return self

async def __anext__(self) -> Union[ChatResponse]:
try:
chunk = await self._aiter.__anext__()
if not isinstance(chunk, ModelResponse):
return chunk
# Convert the timestamp to a datetime object
datetime_obj = datetime.fromtimestamp(chunk.created, tz=timezone.utc)
created_at = datetime_obj.isoformat()

message = chunk.choices[0].delta.content
done = False
if chunk.choices[0].finish_reason == "stop":
done = True
message = ""

# Convert the model response to an Ollama response
ollama_response = ChatResponse(
model=chunk.model,
created_at=created_at,
done=done,
message=Message(content=message, role="assistant"),
)
return ollama_response
except StopAsyncIteration:
raise StopAsyncIteration


class OllamaOutputNormalizer(ModelOutputNormalizer):
def __init__(self):
super().__init__()

def normalize_streaming(
self,
model_reply: Any,
) -> Any:
model_reply: AsyncIterator[ChatResponse],
) -> AsyncIterator[ModelResponse]:
"""
Pass through Ollama response
"""
return model_reply
return OLlamaToModel(model_reply)

def normalize(self, model_reply: Any) -> Any:
"""
@@ -68,10 +144,9 @@ def denormalize(self, normalized_reply: Any) -> Any:
return normalized_reply

def denormalize_streaming(
self,
normalized_reply: Any,
) -> Any:
self, normalized_reply: AsyncIterator[ModelResponse]
) -> AsyncIterator[ChatResponse]:
"""
Pass through Ollama response
"""
return normalized_reply
return ModelToOllama(normalized_reply)
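To illustrate how the two new adapters fit together (a rough wiring sketch, not code from this PR; the model name is a placeholder): `OllamaOutputNormalizer.normalize_streaming` wraps the raw Ollama stream in `OLlamaToModel`, yielding OpenAI-style streaming chunks for the rest of the pipeline, while `denormalize_streaming` converts such chunks back into `ChatResponse` objects for the client.

```python
import asyncio

from ollama import AsyncClient

from codegate.providers.ollama.adapter import OllamaOutputNormalizer


async def main() -> None:
    client = AsyncClient(host="http://localhost:11434")
    ollama_stream = await client.chat(
        model="llama3",  # placeholder model name
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )

    normalizer = OllamaOutputNormalizer()
    # OLlamaToModel turns each ChatResponse chunk into a litellm ModelResponse.
    async for model_chunk in normalizer.normalize_streaming(ollama_stream):
        delta = model_chunk.choices[0].delta
        print(delta.content or "", end="", flush=True)


asyncio.run(main())
```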