Skip to content

Commit 0c1f337

Browse files
authored
Merge pull request #625 from aurelio-labs/josh/semantic-router/local-index/metadata
feat: added metadata attribute to local index
2 parents ff37bf8 + baf3609 commit 0c1f337

File tree

4 files changed

+88
-10
lines changed

4 files changed

+88
-10
lines changed

semantic_router/encoders/openai.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def __call__(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
137137
# Exponential backoff
138138
for j in range(self.max_retries + 1):
139139
try:
140+
logger.debug(f"Creating embeddings for {len(docs)} docs")
140141
embeds = self._client.embeddings.create(
141142
input=docs,
142143
model=self.name,

semantic_router/index/hybrid_local.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class HybridLocalIndex(LocalIndex):
1515

1616
def __init__(self, **data):
1717
super().__init__(**data)
18+
self.metadata = None
1819

1920
def add(
2021
self,
@@ -64,12 +65,32 @@ def add(
6465
] # TODO: switch back to using SparseEmbedding later
6566
self.routes = routes_arr
6667
self.utterances = utterances_arr
68+
self.metadata = (
69+
np.array(metadata_list, dtype=object)
70+
if metadata_list
71+
else np.array([{} for _ in utterances], dtype=object)
72+
)
6773
else:
6874
# TODO: we should probably switch to an `upsert` method and standardize elsewhere
6975
self.index = np.concatenate([self.index, embeds])
7076
self.sparse_index.extend([x.to_dict() for x in sparse_embeddings])
7177
self.routes = np.concatenate([self.routes, routes_arr])
7278
self.utterances = np.concatenate([self.utterances, utterances_arr])
79+
if self.metadata is not None:
80+
self.metadata = np.concatenate(
81+
[
82+
self.metadata,
83+
np.array(metadata_list, dtype=object)
84+
if metadata_list
85+
else np.array([{} for _ in utterances], dtype=object),
86+
]
87+
)
88+
else:
89+
self.metadata = (
90+
np.array(metadata_list, dtype=object)
91+
if metadata_list
92+
else np.array([{} for _ in utterances], dtype=object)
93+
)
7394

7495
async def aadd(
7596
self,
@@ -119,7 +140,20 @@ def get_utterances(self, include_metadata: bool = False) -> List[Utterance]:
119140
"""
120141
if self.routes is None or self.utterances is None:
121142
return []
122-
return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)]
143+
if include_metadata and self.metadata is not None:
144+
return [
145+
Utterance(
146+
route=route,
147+
utterance=utterance,
148+
function_schemas=None,
149+
metadata=metadata,
150+
)
151+
for route, utterance, metadata in zip(
152+
self.routes, self.utterances, self.metadata
153+
)
154+
]
155+
else:
156+
return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)]
123157

124158
def _sparse_dot_product(
125159
self, vec_a: dict[int, float], vec_b: dict[int, float]
@@ -260,6 +294,8 @@ def delete(self, route_name: str):
260294
self.index = np.delete(self.index, delete_idx, axis=0)
261295
self.routes = np.delete(self.routes, delete_idx, axis=0)
262296
self.utterances = np.delete(self.utterances, delete_idx, axis=0)
297+
if self.metadata is not None:
298+
self.metadata = np.delete(self.metadata, delete_idx, axis=0)
263299
else:
264300
raise ValueError(
265301
"Attempted to delete route records but either index, routes or "
@@ -275,6 +311,7 @@ def delete_index(self):
275311
self.index = None
276312
self.routes = None
277313
self.utterances = None
314+
self.metadata = None
278315

279316
def _get_indices_for_route(self, route_name: str):
280317
"""Gets an array of indices for a specific route.

semantic_router/index/local.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, ClassVar, Dict, List, Optional, Tuple
22

33
import numpy as np
4-
from pydantic import ConfigDict
4+
from pydantic import ConfigDict, Field
55

66
from semantic_router.index.base import BaseIndex, IndexConfig
77
from semantic_router.linear import similarity_matrix, top_scores
@@ -11,9 +11,12 @@
1111

1212
class LocalIndex(BaseIndex):
1313
type: str = "local"
14+
metadata: Optional[np.ndarray] = Field(default=None, exclude=True)
1415

1516
def __init__(self, **data):
1617
super().__init__(**data)
18+
if self.metadata is None:
19+
self.metadata = None
1720

1821
# Stop pydantic from complaining about Optional[np.ndarray]type hints.
1922
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
@@ -50,10 +53,30 @@ def add(
5053
self.index = embeds # type: ignore
5154
self.routes = routes_arr
5255
self.utterances = utterances_arr
56+
self.metadata = (
57+
np.array(metadata_list, dtype=object)
58+
if metadata_list
59+
else np.array([{} for _ in utterances], dtype=object)
60+
)
5361
else:
5462
self.index = np.concatenate([self.index, embeds])
5563
self.routes = np.concatenate([self.routes, routes_arr])
5664
self.utterances = np.concatenate([self.utterances, utterances_arr])
65+
if self.metadata is not None:
66+
self.metadata = np.concatenate(
67+
[
68+
self.metadata,
69+
np.array(metadata_list, dtype=object)
70+
if metadata_list
71+
else np.array([{} for _ in utterances], dtype=object),
72+
]
73+
)
74+
else:
75+
self.metadata = (
76+
np.array(metadata_list, dtype=object)
77+
if metadata_list
78+
else np.array([{} for _ in utterances], dtype=object)
79+
)
5780

5881
def _remove_and_sync(self, routes_to_delete: dict) -> np.ndarray:
5982
"""Remove and sync the index.
@@ -80,21 +103,35 @@ def _remove_and_sync(self, routes_to_delete: dict) -> np.ndarray:
80103
self.index = self.index[mask]
81104
self.routes = self.routes[mask]
82105
self.utterances = self.utterances[mask]
106+
if self.metadata is not None:
107+
self.metadata = self.metadata[mask]
83108
# return what was removed
84109
return route_utterances[~mask]
85110

86111
def get_utterances(self, include_metadata: bool = False) -> List[Utterance]:
87112
"""Gets a list of route and utterance objects currently stored in the index.
88113
89114
:param include_metadata: Whether to include function schemas and metadata in
90-
the returned Utterance objects - LocalIndex doesn't include metadata so this
91-
parameter is ignored.
115+
the returned Utterance objects - LocalIndex now includes metadata if present.
92116
:return: A list of Utterance objects.
93117
:rtype: List[Utterance]
94118
"""
95119
if self.routes is None or self.utterances is None:
96120
return []
97-
return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)]
121+
if include_metadata and self.metadata is not None:
122+
return [
123+
Utterance(
124+
route=route,
125+
utterance=utterance,
126+
function_schemas=None,
127+
metadata=metadata,
128+
)
129+
for route, utterance, metadata in zip(
130+
self.routes, self.utterances, self.metadata
131+
)
132+
]
133+
else:
134+
return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)]
98135

99136
def describe(self) -> IndexConfig:
100137
"""Describe the index.
@@ -235,6 +272,8 @@ def delete(self, route_name: str):
235272
self.index = np.delete(self.index, delete_idx, axis=0)
236273
self.routes = np.delete(self.routes, delete_idx, axis=0)
237274
self.utterances = np.delete(self.utterances, delete_idx, axis=0)
275+
if self.metadata is not None:
276+
self.metadata = np.delete(self.metadata, delete_idx, axis=0)
238277
else:
239278
raise ValueError(
240279
"Attempted to delete route records but either index, routes or "
@@ -260,6 +299,7 @@ def delete_index(self):
260299
self.index = None
261300
self.routes = None
262301
self.utterances = None
302+
self.metadata = None
263303

264304
async def adelete_index(self):
265305
"""Deletes the index, effectively clearing it and setting it to None. Note that this just points
@@ -272,6 +312,7 @@ async def adelete_index(self):
272312
self.index = None
273313
self.routes = None
274314
self.utterances = None
315+
self.metadata = None
275316

276317
def _get_indices_for_route(self, route_name: str):
277318
"""Gets an array of indices for a specific route.

semantic_router/utils/logger.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,11 @@ def setup_custom_logger(name):
4343
)
4444
log_level = log_level.upper()
4545

46-
if not logger.hasHandlers():
47-
add_coloured_handler(logger)
48-
logger.setLevel(log_level)
49-
logger.propagate = False
46+
add_coloured_handler(logger)
47+
logger.setLevel(log_level)
48+
logger.propagate = False
5049

5150
return logger
5251

5352

54-
logger: logging.Logger = setup_custom_logger(__name__)
53+
logger: logging.Logger = setup_custom_logger("semantic_router")

0 commit comments

Comments
 (0)