Skip to content

Commit 567e4c9

Browse files
authored
Merge pull request #397 from aurelio-labs/vittorio/add-async-get-routes-method-to-pinecone-index
feat: Implemented aget_routes async method for pinecone index
2 parents db451eb + 7f2cfae commit 567e4c9

File tree

9 files changed

+137
-9
lines changed

9 files changed

+137
-9
lines changed

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
project = "Semantic Router"
1616
copyright = "2024, Aurelio AI"
1717
author = "Aurelio AI"
18-
release = "0.0.60"
18+
release = "0.0.61"
1919

2020
# -- General configuration ---------------------------------------------------
2121
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "semantic-router"
3-
version = "0.0.60"
3+
version = "0.0.61"
44
description = "Super fast semantic router for AI decision making"
55
authors = [
66
"James Briggs <[email protected]>",

semantic_router/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44

55
__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"]
66

7-
__version__ = "0.0.60"
7+
__version__ = "0.0.61"

semantic_router/index/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,17 @@ async def aquery(
9090
"""
9191
raise NotImplementedError("This method should be implemented by subclasses.")
9292

93+
async def aget_routes(self):
    """
    Asynchronously get a list of route and utterance objects currently stored in the index.
    This method should be implemented by subclasses.

    Declared as a coroutine function so that it follows the same calling
    convention as the other ``a``-prefixed methods of this class (``aquery``)
    and as overrides written with ``async def``: the call returns an awaitable
    and the error below is raised when it is awaited.

    :returns: A list of tuples, each containing a route name and an associated utterance.
    :rtype: list[tuple]
    :raises NotImplementedError: If the method is not implemented by the subclass.
    """
    raise NotImplementedError("This method should be implemented by subclasses.")
103+
93104
def delete_index(self):
94105
"""
95106
Deletes or resets the index.

semantic_router/index/local.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ async def aquery(
128128
route_names = [self.routes[i] for i in idx]
129129
return scores, route_names
130130

131+
async def aget_routes(self):
    """
    Placeholder for asynchronous route listing on LocalIndex.

    The operation is not implemented for this index type: an error is logged
    and ``None`` is returned. It is a coroutine function so that
    ``await index.aget_routes()`` completes instead of failing with a
    ``TypeError`` from awaiting ``None``.
    """
    # The previous message ("Sync remove ...") described a different operation.
    logger.error("Async get routes is not implemented for LocalIndex.")
133+
131134
def delete(self, route_name: str):
132135
"""
133136
Delete all records of a specific route from the index.

semantic_router/index/pinecone.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,18 @@ async def aquery(
528528
route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
529529
return np.array(scores), route_names
530530

531+
async def aget_routes(self) -> list[tuple]:
    """
    Return the (route_name, utterance) pairs stored in the index, asynchronously.

    The work is delegated to ``_async_get_routes`` once the async client and
    the host have both been confirmed to be set.

    Returns:
        List[Tuple]: A list of (route_name, utterance) objects.

    Raises:
        ValueError: If the async client or the host is ``None``.
    """
    ready = self.async_client is not None and self.host is not None
    if not ready:
        raise ValueError("Async client or host are not initialized.")

    routes = await self._async_get_routes()
    return routes
542+
531543
def delete_index(self):
532544
self.client.delete_index(self.index_name)
533545

@@ -584,5 +596,101 @@ async def _async_describe_index(self, name: str):
584596
async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response:
585597
return await response.json(content_type=None)
586598

599+
async def _async_get_all(
    self, prefix: Optional[str] = None, include_metadata: bool = False
) -> tuple[list[str], list[dict]]:
    """
    Retrieves all vector IDs from the Pinecone index using pagination asynchronously.

    :param prefix: When given, it is sent as the ``prefix`` query parameter of
        the list request.
    :param include_metadata: When True, metadata is fetched for every listed
        ID (one ``_async_fetch_metadata`` call per ID, gathered per page).
    :returns: ``(all_vector_ids, metadata)``; ``metadata`` stays empty unless
        ``include_metadata`` is True.
    :raises ValueError: If ``self.index`` is None.

    A non-200 response is logged and stops pagination; whatever was collected
    up to that point is returned (best-effort, no exception).
    """
    if self.index is None:
        raise ValueError("Index is None, could not retrieve vector IDs.")

    list_url = f"https://{self.host}/vectors/list"
    headers = {"Api-Key": self.api_key}

    # The prefix is passed through `params` rather than being pasted into the
    # URL so that the HTTP client percent-encodes it; a raw prefix containing
    # characters such as '&' or '#' would otherwise corrupt the query string.
    params: dict = {}
    if prefix:
        params["prefix"] = prefix
    if self.namespace:
        params["namespace"] = self.namespace

    all_vector_ids: list[str] = []
    metadata: list[dict] = []

    while True:
        async with self.async_client.get(
            list_url, params=params, headers=headers
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"Error fetching vectors: {error_text}")
                break

            response_data = await response.json(content_type=None)

        # The response is released before the (potentially large) metadata
        # fan-out below, so the list request does not hold a connection.
        vector_ids = [vec["id"] for vec in response_data.get("vectors", [])]
        if not vector_ids:
            break
        all_vector_ids.extend(vector_ids)

        if include_metadata:
            metadata_results = await asyncio.gather(
                *(self._async_fetch_metadata(vector_id) for vector_id in vector_ids)
            )
            metadata.extend(metadata_results)

        next_page_token = response_data.get("pagination", {}).get("next")
        if not next_page_token:
            break
        params["paginationToken"] = next_page_token

    return all_vector_ids, metadata
651+
652+
async def _async_fetch_metadata(self, vector_id: str) -> dict:
    """
    Fetch metadata for a single vector ID asynchronously using the async_client.

    :param vector_id: ID of the vector whose metadata is requested.
    :returns: The vector's metadata dict, or ``{}`` when the request does not
        return 200, the body cannot be parsed, or no metadata is present for
        ``vector_id`` in the response.
    """
    url = f"https://{self.host}/vectors/fetch"

    params: dict = {"ids": [vector_id]}
    # Scope the fetch to the index namespace when one is configured, matching
    # the namespace filter applied when the IDs are listed. Without it the
    # request would not look in the namespace the IDs were listed from.
    if self.namespace:
        params["namespace"] = self.namespace

    headers = {"Api-Key": self.api_key}

    async with self.async_client.get(url, params=params, headers=headers) as response:
        if response.status != 200:
            error_text = await response.text()
            logger.error(f"Error fetching metadata: {error_text}")
            return {}

        try:
            response_data = await response.json(content_type=None)
        except Exception as e:
            logger.warning(f"No metadata found for vector {vector_id}: {e}")
            return {}

    vector = response_data.get("vectors", {}).get(vector_id, {})
    # `or {}` keeps the declared dict return type even if the API sends a
    # null metadata value.
    return vector.get("metadata") or {}
683+
684+
async def _async_get_routes(self) -> list[tuple]:
    """
    Gets a list of route and utterance objects currently stored in the index.

    Metadata records that lack ``sr_route`` or ``sr_utterance`` are skipped.
    ``_async_fetch_metadata`` returns an empty dict when a fetch fails, and
    indexing such a record would raise ``KeyError`` and discard every route
    that was retrieved successfully.

    Returns:
        List[Tuple]: A list of (route_name, utterance) objects.
    """
    _, metadata = await self._async_get_all(include_metadata=True)
    return [
        (record["sr_route"], record["sr_utterance"])
        for record in metadata
        if record and "sr_route" in record and "sr_utterance" in record
    ]
694+
587695
def __len__(self):
    """Return the ``total_vector_count`` entry of ``self.index.describe_index_stats()``."""
    stats = self.index.describe_index_stats()
    return stats["total_vector_count"]

semantic_router/index/postgres.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from semantic_router.index.base import BaseIndex
1111
from semantic_router.schema import Metric
12+
from semantic_router.utils.logger import logger
1213

1314

1415
class MetricPgVecOperatorMap(Enum):
@@ -456,6 +457,9 @@ def delete_index(self) -> None:
456457
cur.execute(f"DROP TABLE IF EXISTS {table_name}")
457458
self.conn.commit()
458459

460+
async def aget_routes(self):
    """
    Placeholder for asynchronous route listing on PostgresIndex.

    The operation is not implemented for this index type: an error is logged
    and ``None`` is returned. It is a coroutine function so that
    ``await index.aget_routes()`` completes instead of failing with a
    ``TypeError`` from awaiting ``None``.
    """
    # The previous message ("Sync remove ...") described a different operation.
    logger.error("Async get routes is not implemented for PostgresIndex.")
462+
459463
def __len__(self):
460464
"""
461465
Returns the total number of vectors in the index.

semantic_router/index/qdrant.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,9 @@ async def aquery(
317317
route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
318318
return np.array(scores), route_names
319319

320+
async def aget_routes(self):
    """
    Placeholder for asynchronous route listing on QdrantIndex.

    The operation is not implemented for this index type: an error is logged
    and ``None`` is returned. It is a coroutine function so that
    ``await index.aget_routes()`` completes instead of failing with a
    ``TypeError`` from awaiting ``None``.
    """
    # The previous message ("Sync remove ...") described a different operation.
    logger.error("Async get routes is not implemented for QdrantIndex.")
322+
320323
def delete_index(self):
321324
self.client.delete_collection(self.index_name)
322325

tests/unit/encoders/test_vit.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from semantic_router.encoders import VitEncoder
88

99
test_model_name = "aurelio-ai/sr-test-vit"
10-
vit_encoder = VitEncoder(name=test_model_name)
1110
embed_dim = 32
1211

1312
if torch.cuda.is_available():
@@ -44,15 +43,11 @@ def test_vit_encoder__import_errors_torch(self, mocker):
4443
with pytest.raises(ImportError):
4544
VitEncoder()
4645

47-
def test_vit_encoder__import_errors_torchvision(self, mocker):
48-
mocker.patch.dict("sys.modules", {"torchvision": None})
49-
with pytest.raises(ImportError):
50-
VitEncoder()
51-
5246
@pytest.mark.skipif(
5347
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
5448
)
5549
def test_vit_encoder_initialization(self):
50+
vit_encoder = VitEncoder(name=test_model_name)
5651
assert vit_encoder.name == test_model_name
5752
assert vit_encoder.type == "huggingface"
5853
assert vit_encoder.score_threshold == 0.5
@@ -62,6 +57,7 @@ def test_vit_encoder_initialization(self):
6257
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
6358
)
6459
def test_vit_encoder_call(self, dummy_pil_image):
60+
vit_encoder = VitEncoder(name=test_model_name)
6561
encoded_images = vit_encoder([dummy_pil_image] * 3)
6662

6763
assert len(encoded_images) == 3
@@ -71,6 +67,7 @@ def test_vit_encoder_call(self, dummy_pil_image):
7167
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
7268
)
7369
def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
70+
vit_encoder = VitEncoder(name=test_model_name)
7471
encoded_images = vit_encoder([dummy_pil_image, misshaped_pil_image])
7572

7673
assert len(encoded_images) == 2
@@ -80,6 +77,7 @@ def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
8077
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
8178
)
8279
def test_vit_encoder_process_images_device(self, dummy_pil_image):
80+
vit_encoder = VitEncoder(name=test_model_name)
8381
imgs = vit_encoder._process_images([dummy_pil_image] * 3)["pixel_values"]
8482

8583
assert imgs.device.type == device
@@ -88,6 +86,7 @@ def test_vit_encoder_process_images_device(self, dummy_pil_image):
8886
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
8987
)
9088
def test_vit_encoder_ensure_rgb(self, dummy_black_and_white_img):
89+
vit_encoder = VitEncoder(name=test_model_name)
9190
rgb_image = vit_encoder._ensure_rgb(dummy_black_and_white_img)
9291

9392
assert rgb_image.mode == "RGB"

0 commit comments

Comments
 (0)