Skip to content

Commit 69b4932

Browse files
authored
Merge pull request #487 from aurelio-labs/james/async-sync
feat: async sync and pinecone methods
2 parents 2b9720f + 9915053 commit 69b4932

File tree

4 files changed

+795
-37
lines changed

4 files changed

+795
-37
lines changed

semantic_router/index/base.py

Lines changed: 149 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from datetime import datetime
23
import time
34
from typing import Any, List, Optional, Tuple, Union, Dict
@@ -11,6 +12,9 @@
1112
from semantic_router.utils.logger import logger
1213

1314

15+
# Seconds to sleep between successive lock-state checks while `lock` / `alock`
# wait for the index lock to change.
RETRY_WAIT_TIME = 2.5
16+
17+
1418
class BaseIndex(BaseModel):
1519
"""
1620
Base class for indices using Pydantic's BaseModel.
@@ -35,12 +39,31 @@ def add(
3539
function_schemas: Optional[List[Dict[str, Any]]] = None,
3640
metadata_list: List[Dict[str, Any]] = [],
3741
):
38-
"""
39-
Add embeddings to the index.
42+
"""Add embeddings to the index.
4043
This method should be implemented by subclasses.
4144
"""
4245
raise NotImplementedError("This method should be implemented by subclasses.")
4346

47+
async def aadd(
    self,
    embeddings: List[List[float]],
    routes: List[str],
    utterances: List[str],
    function_schemas: Optional[List[Dict[str, Any]]] = None,
    metadata_list: List[Dict[str, Any]] = [],
):
    """Add vectors to the index asynchronously.

    Subclasses that have a native async client should override this method.
    The base implementation logs a warning and falls back to the synchronous
    ``add`` method, which is executed inline inside this coroutine (it is
    not awaited or offloaded).

    :param embeddings: Embedding vectors, forwarded unchanged to ``add``.
    :type embeddings: List[List[float]]
    :param routes: Forwarded unchanged to ``add``.
    :type routes: List[str]
    :param utterances: Forwarded unchanged to ``add``.
    :type utterances: List[str]
    :param function_schemas: Forwarded unchanged to ``add``.
    :type function_schemas: Optional[List[Dict[str, Any]]]
    :param metadata_list: Forwarded unchanged to ``add``. The default
        mirrors the signature of ``add`` and is never mutated here.
    :type metadata_list: List[Dict[str, Any]]
    :return: Whatever the synchronous ``add`` implementation returns.
    """
    logger.warning("Async method not implemented.")
    return self.add(
        embeddings=embeddings,
        routes=routes,
        utterances=utterances,
        function_schemas=function_schemas,
        metadata_list=metadata_list,
    )
66+
4467
def get_utterances(self) -> List[Utterance]:
4568
"""Gets a list of route and utterance objects currently stored in the
4669
index, including additional metadata.
@@ -56,6 +79,21 @@ def get_utterances(self) -> List[Utterance]:
5679
route_tuples = parse_route_info(metadata=metadata)
5780
return [Utterance.from_tuple(x) for x in route_tuples]
5881

82+
async def aget_utterances(self) -> List[Utterance]:
    """Asynchronously collect the route/utterance records held by the index.

    The metadata of every stored record is fetched through
    ``_async_get_all``, parsed with ``parse_route_info`` and converted into
    ``Utterance`` objects.

    :return: One ``Utterance`` per parsed route tuple, or an empty list when
        no index is set.
    :rtype: List[Utterance]
    """
    if self.index is not None:
        # Only the metadata is needed; the first element is discarded.
        _, all_metadata = await self._async_get_all(include_metadata=True)
        parsed = parse_route_info(metadata=all_metadata)
        return [Utterance.from_tuple(info) for info in parsed]
    logger.warning("Index is None, could not retrieve utterances.")
    return []
96+
5997
def get_routes(self) -> List[Route]:
6098
"""Gets a list of route objects currently stored in the index.
6199
@@ -90,6 +128,14 @@ def _remove_and_sync(self, routes_to_delete: dict):
90128
"""
91129
raise NotImplementedError("This method should be implemented by subclasses.")
92130

131+
async def _async_remove_and_sync(self, routes_to_delete: dict):
    """Asynchronous counterpart of ``_remove_and_sync``.

    The base implementation only logs a warning and delegates to the
    synchronous ``_remove_and_sync``; subclasses are expected to override it.

    :param routes_to_delete: Passed through unchanged to ``_remove_and_sync``.
    :type routes_to_delete: dict
    :return: Whatever ``_remove_and_sync`` returns.
    """
    logger.warning("Async method not implemented.")
    sync_result = self._remove_and_sync(routes_to_delete=routes_to_delete)
    return sync_result
138+
93139
def delete(self, route_name: str):
94140
"""
95141
Deletes route by route name.
@@ -159,6 +205,10 @@ def delete_index(self):
159205
logger.warning("This method should be implemented by subclasses.")
160206
self.index = None
161207

208+
# ___________________________ CONFIG ___________________________
209+
# When implementing a new index, the following methods should be implemented
210+
# to enable synchronization of remote indexes.
211+
162212
def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter:
163213
"""Read a config parameter from the index.
164214
@@ -176,13 +226,20 @@ def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter:
176226
scope=scope,
177227
)
178228

179-
def _read_hash(self) -> ConfigParameter:
180-
"""Read the hash of the previously written index.
229+
async def _async_read_config(
230+
self, field: str, scope: str | None = None
231+
) -> ConfigParameter:
232+
"""Read a config parameter from the index asynchronously.
181233
234+
:param field: The field to read.
235+
:type field: str
236+
:param scope: The scope to read.
237+
:type scope: str | None
182238
:return: The config parameter that was read.
183239
:rtype: ConfigParameter
184240
"""
185-
return self._read_config(field="sr_hash")
241+
logger.warning("Async method not implemented.")
242+
return self._read_config(field=field, scope=scope)
186243

187244
def _write_config(self, config: ConfigParameter) -> ConfigParameter:
188245
"""Write a config parameter to the index.
@@ -195,6 +252,67 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter:
195252
logger.warning("This method should be implemented by subclasses.")
196253
return config
197254

255+
async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter:
    """Asynchronous counterpart of ``_write_config``.

    The base implementation logs a warning and delegates to the synchronous
    ``_write_config``.

    :param config: The config parameter to write.
    :type config: ConfigParameter
    :return: The value returned by ``_write_config``.
    :rtype: ConfigParameter
    """
    logger.warning("Async method not implemented.")
    written = self._write_config(config=config)
    return written
265+
266+
# _________________________ END CONFIG _________________________
267+
268+
def _read_hash(self) -> ConfigParameter:
    """Fetch the stored ``sr_hash`` config parameter.

    :return: The config parameter returned by ``_read_config`` for the
        ``sr_hash`` field.
    :rtype: ConfigParameter
    """
    hash_param = self._read_config(field="sr_hash")
    return hash_param
275+
276+
async def _async_read_hash(self) -> ConfigParameter:
    """Fetch the stored ``sr_hash`` config parameter asynchronously.

    :return: The config parameter returned by ``_async_read_config`` for
        the ``sr_hash`` field.
    :rtype: ConfigParameter
    """
    hash_param = await self._async_read_config(field="sr_hash")
    return hash_param
283+
284+
def _is_locked(self, scope: str | None = None) -> bool:
285+
"""Check if the index is locked for a given scope (if applicable).
286+
287+
:param scope: The scope to check.
288+
:type scope: str | None
289+
:return: True if the index is locked, False otherwise.
290+
:rtype: bool
291+
"""
292+
lock_config = self._read_config(field="sr_lock", scope=scope)
293+
if lock_config.value == "True":
294+
return True
295+
elif lock_config.value == "False" or not lock_config.value:
296+
return False
297+
else:
298+
raise ValueError(f"Invalid lock value: {lock_config.value}")
299+
300+
async def _ais_locked(self, scope: str | None = None) -> bool:
301+
"""Check if the index is locked for a given scope (if applicable).
302+
303+
:param scope: The scope to check.
304+
:type scope: str | None
305+
:return: True if the index is locked, False otherwise.
306+
:rtype: bool
307+
"""
308+
lock_config = await self._async_read_config(field="sr_lock", scope=scope)
309+
if lock_config.value == "True":
310+
return True
311+
elif lock_config.value == "False" or not lock_config.value:
312+
return False
313+
else:
314+
raise ValueError(f"Invalid lock value: {lock_config.value}")
315+
198316
def lock(
199317
self, value: bool, wait: int = 0, scope: str | None = None
200318
) -> ConfigParameter:
@@ -215,8 +333,8 @@ def lock(
215333
# in this case, we can set the lock value
216334
break
217335
if (datetime.now() - start_time).total_seconds() < wait:
218-
# wait for 2.5 seconds before checking again
219-
time.sleep(2.5)
336+
# wait for a few seconds before checking again
337+
time.sleep(RETRY_WAIT_TIME)
220338
else:
221339
raise ValueError(
222340
f"Index is already {'locked' if value else 'unlocked'}."
@@ -229,21 +347,31 @@ def lock(
229347
self._write_config(lock_param)
230348
return lock_param
231349

232-
def _is_locked(self, scope: str | None = None) -> bool:
233-
"""Check if the index is locked for a given scope (if applicable).
234-
235-
:param scope: The scope to check.
236-
:type scope: str | None
237-
:return: True if the index is locked, False otherwise.
238-
:rtype: bool
350+
async def alock(
351+
self, value: bool, wait: int = 0, scope: str | None = None
352+
) -> ConfigParameter:
353+
"""Lock/unlock the index for a given scope (if applicable). If index
354+
already locked/unlocked, raises ValueError.
239355
"""
240-
lock_config = self._read_config(field="sr_lock", scope=scope)
241-
if lock_config.value == "True":
242-
return True
243-
elif lock_config.value == "False" or not lock_config.value:
244-
return False
245-
else:
246-
raise ValueError(f"Invalid lock value: {lock_config.value}")
356+
start_time = datetime.now()
357+
while True:
358+
if await self._ais_locked(scope=scope) != value:
359+
# in this case, we can set the lock value
360+
break
361+
if (datetime.now() - start_time).total_seconds() < wait:
362+
# wait for a few seconds before checking again
363+
await asyncio.sleep(RETRY_WAIT_TIME)
364+
else:
365+
raise ValueError(
366+
f"Index is already {'locked' if value else 'unlocked'}."
367+
)
368+
lock_param = ConfigParameter(
369+
field="sr_lock",
370+
value=str(value),
371+
scope=scope,
372+
)
373+
await self._async_write_config(lock_param)
374+
return lock_param
247375

248376
def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
249377
"""

0 commit comments

Comments
 (0)