Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit ef28395

Browse files
committed
feat: enable weaviate usage in codegate
Closes: #63
1 parent dc988a3 commit ef28395

File tree

3 files changed

+135
-0
lines changed

3 files changed

+135
-0
lines changed

src/codegate/storage/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
from codegate.storage.storage_engine import StorageEngine

# __all__ entries must be strings naming the public attributes; listing the
# class object itself makes ``from codegate.storage import *`` raise TypeError.
__all__ = ["StorageEngine"]
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from codegate.codegate_logging import setup_logging
2+
from codegate.inference.inference_engine import LlamaCppInferenceEngine
3+
from weaviate.classes.config import DataType, Property
4+
from weaviate.classes.query import MetadataQuery
5+
import weaviate
6+
7+
8+
# Weaviate collection definitions used by StorageEngine.setup_schema():
# a single "Package" class holding a package's name, ecosystem type,
# status, and free-text description.
schema_config = [
    {
        "name": "Package",
        "properties": [
            {"name": "name", "data_type": DataType.TEXT},
            {"name": "type", "data_type": DataType.TEXT},
            {"name": "status", "data_type": DataType.TEXT},
            {"name": "description", "data_type": DataType.TEXT},
        ],
    },
]
19+
20+
21+
class StorageEngine:
    """Thin wrapper around an embedded Weaviate instance.

    On construction it attaches to the embedded database under ``data_path``,
    verifies the connection, and makes sure every collection listed in
    ``schema_config`` exists.  Query embedding is done locally through
    ``LlamaCppInferenceEngine`` with a GGUF model.
    """

    def __init__(self, data_path='./weaviate_data'):
        self.client = weaviate.WeaviateClient(
            embedded_options=weaviate.EmbeddedOptions(
                persistence_data_path=data_path
            ),
        )
        self.__logger = setup_logging()
        self.inference_engine = LlamaCppInferenceEngine()
        self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
        self.schema_config = schema_config
        # Connect and create the schema eagerly so the instance is fully
        # usable as soon as the constructor returns.
        self.connect()
        self.setup_schema()

    def connect(self):
        """Open the connection and fail fast when the client is not ready.

        Raises:
            Exception: if the Weaviate client reports it is not ready.
        """
        self.client.connect()
        if not self.client.is_ready():
            raise Exception("Weaviate client is not ready.")
        self.__logger.info("Weaviate connection established and client is ready.")

    def setup_schema(self):
        """Create every collection from ``schema_config`` that is missing."""
        for cfg in self.schema_config:
            collection_name = cfg['name']
            if not self.client.collections.exists(collection_name):
                # NOTE(review): properties are plain dicts here, while
                # weaviate.classes.config.Property is imported above —
                # confirm collections.create accepts the dict form.
                self.client.collections.create(
                    collection_name, properties=cfg['properties']
                )
            self.__logger.info(
                f"Weaviate schema for class {collection_name} setup complete."
            )

    async def search(self, query, limit=5, distance=0.1):
        """
        Search the 'Package' collection based on a query string.

        Args:
            query (str): The text query for which to search.
            limit (int): The number of results to return.
            distance (float): Maximum vector distance for a match.

        Returns:
            list: A list of matching results with their properties and
            distances; empty on no match or on error.
        """
        # Embed the query with the local GGUF model.  NOTE(review): embed()
        # is given a one-element batch and its whole result is forwarded to
        # near_vector below — confirm it should not be the first vector only.
        query_vector = await self.inference_engine.embed(self.model_path, [query])

        try:
            packages = self.client.collections.get("Package")
            response = packages.query.near_vector(
                query_vector,
                limit=limit,
                distance=distance,
                return_metadata=MetadataQuery(distance=True),
            )
            return response.objects if response else []
        except Exception as e:
            # Best-effort search: log and degrade to an empty result set.
            self.__logger.error(f"Error during search: {str(e)}")
            return []

    def close(self):
        """Shut down the underlying Weaviate client."""
        self.client.close()
        self.__logger.info("Connection closed.")

tests/test_storage.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
from unittest.mock import Mock, AsyncMock
3+
from codegate.storage.storage_engine import StorageEngine # Adjust the import according to your project structure
4+
5+
6+
@pytest.fixture
def mock_client():
    """Mock Weaviate client mirroring the v4 API that StorageEngine calls.

    The engine talks to ``client.collections`` (v4); the legacy v3
    ``client.schema.contains`` / ``schema.create_class`` attributes that the
    original fixture configured are never called, so they are dropped in
    favour of the collection-level calls actually made by setup_schema().
    """
    client = Mock()
    client.connect = Mock()
    client.is_ready = Mock(return_value=True)
    client.collections.exists = Mock(return_value=False)
    client.collections.create = Mock()
    client.collections.get = Mock()
    client.close = Mock()
    return client
16+
17+
18+
@pytest.fixture
def mock_logger():
    """A bare Mock standing in for the logger produced by setup_logging()."""
    return Mock()
22+
23+
24+
@pytest.fixture
def mock_inference_engine():
    """AsyncMock for LlamaCppInferenceEngine whose embed() yields a fixed vector."""
    engine = AsyncMock()
    # The exact dimensionality is irrelevant to these tests; any small,
    # deterministic vector will do.
    engine.embed = AsyncMock(return_value=[0.1, 0.2, 0.3])
    return engine
29+
30+
31+
@pytest.fixture
def storage_engine(mock_client, mock_logger, mock_inference_engine):
    """Build a StorageEngine wired to mocks without touching Weaviate.

    ``StorageEngine.__init__`` eagerly connects and creates the schema, so
    constructing one normally would start a real embedded Weaviate before the
    mocks could be injected.  ``__new__`` bypasses ``__init__``; the
    attributes the tests exercise are then set by hand.
    """
    engine = StorageEngine.__new__(StorageEngine)
    engine.client = mock_client
    # Inside the class the logger is referenced as ``self.__logger``, which
    # Python mangles to ``_StorageEngine__logger``.  Assigning
    # ``engine.__logger`` from out here would create an unrelated attribute
    # and leave the real logger in place.
    engine._StorageEngine__logger = mock_logger
    engine.inference_engine = mock_inference_engine
    engine.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
    engine.schema_config = []  # setup_schema is not exercised by these tests
    return engine
38+
39+
40+
def test_connect(storage_engine, mock_client):
    """connect() should open the client and check readiness exactly once."""
    storage_engine.connect()

    mock_client.connect.assert_called_once()
    mock_client.is_ready.assert_called_once()
44+
45+
46+
@pytest.mark.asyncio
async def test_search(storage_engine, mock_client):
    """search() should embed the query text once and return a result object."""
    results = await storage_engine.search("test query")

    storage_engine.inference_engine.embed.assert_called_once_with(
        "./models/all-minilm-L6-v2-q5_k_m.gguf", ["test query"]
    )
    assert results is not None  # Further asserts can be based on your application logic
52+
53+
54+
def test_close(storage_engine, mock_client):
    """close() should delegate to the underlying client exactly once."""
    storage_engine.close()

    mock_client.close.assert_called_once()

0 commit comments

Comments
 (0)