1
1
from typing import Any , ClassVar , Dict , List , Optional , Tuple
2
2
3
3
import numpy as np
4
- from pydantic import ConfigDict
4
+ from pydantic import ConfigDict , Field
5
5
6
6
from semantic_router .index .base import BaseIndex , IndexConfig
7
7
from semantic_router .linear import similarity_matrix , top_scores
11
11
12
12
class LocalIndex (BaseIndex ):
13
13
type : str = "local"
14
+ metadata : Optional [np .ndarray ] = Field (default = None , exclude = True )
14
15
15
16
def __init__ (self , ** data ):
16
17
super ().__init__ (** data )
18
+ if self .metadata is None :
19
+ self .metadata = None
17
20
18
21
# Stop pydantic from complaining about Optional[np.ndarray]type hints.
19
22
model_config : ClassVar [ConfigDict ] = ConfigDict (arbitrary_types_allowed = True )
@@ -50,10 +53,30 @@ def add(
50
53
self .index = embeds # type: ignore
51
54
self .routes = routes_arr
52
55
self .utterances = utterances_arr
56
+ self .metadata = (
57
+ np .array (metadata_list , dtype = object )
58
+ if metadata_list
59
+ else np .array ([{} for _ in utterances ], dtype = object )
60
+ )
53
61
else :
54
62
self .index = np .concatenate ([self .index , embeds ])
55
63
self .routes = np .concatenate ([self .routes , routes_arr ])
56
64
self .utterances = np .concatenate ([self .utterances , utterances_arr ])
65
+ if self .metadata is not None :
66
+ self .metadata = np .concatenate (
67
+ [
68
+ self .metadata ,
69
+ np .array (metadata_list , dtype = object )
70
+ if metadata_list
71
+ else np .array ([{} for _ in utterances ], dtype = object ),
72
+ ]
73
+ )
74
+ else :
75
+ self .metadata = (
76
+ np .array (metadata_list , dtype = object )
77
+ if metadata_list
78
+ else np .array ([{} for _ in utterances ], dtype = object )
79
+ )
57
80
58
81
def _remove_and_sync (self , routes_to_delete : dict ) -> np .ndarray :
59
82
"""Remove and sync the index.
@@ -80,21 +103,35 @@ def _remove_and_sync(self, routes_to_delete: dict) -> np.ndarray:
80
103
self .index = self .index [mask ]
81
104
self .routes = self .routes [mask ]
82
105
self .utterances = self .utterances [mask ]
106
+ if self .metadata is not None :
107
+ self .metadata = self .metadata [mask ]
83
108
# return what was removed
84
109
return route_utterances [~ mask ]
85
110
86
111
def get_utterances (self , include_metadata : bool = False ) -> List [Utterance ]:
87
112
"""Gets a list of route and utterance objects currently stored in the index.
88
113
89
114
:param include_metadata: Whether to include function schemas and metadata in
90
- the returned Utterance objects - LocalIndex doesn't include metadata so this
91
- parameter is ignored.
115
+ the returned Utterance objects - LocalIndex now includes metadata if present.
92
116
:return: A list of Utterance objects.
93
117
:rtype: List[Utterance]
94
118
"""
95
119
if self .routes is None or self .utterances is None :
96
120
return []
97
- return [Utterance .from_tuple (x ) for x in zip (self .routes , self .utterances )]
121
+ if include_metadata and self .metadata is not None :
122
+ return [
123
+ Utterance (
124
+ route = route ,
125
+ utterance = utterance ,
126
+ function_schemas = None ,
127
+ metadata = metadata ,
128
+ )
129
+ for route , utterance , metadata in zip (
130
+ self .routes , self .utterances , self .metadata
131
+ )
132
+ ]
133
+ else :
134
+ return [Utterance .from_tuple (x ) for x in zip (self .routes , self .utterances )]
98
135
99
136
def describe (self ) -> IndexConfig :
100
137
"""Describe the index.
@@ -235,6 +272,8 @@ def delete(self, route_name: str):
235
272
self .index = np .delete (self .index , delete_idx , axis = 0 )
236
273
self .routes = np .delete (self .routes , delete_idx , axis = 0 )
237
274
self .utterances = np .delete (self .utterances , delete_idx , axis = 0 )
275
+ if self .metadata is not None :
276
+ self .metadata = np .delete (self .metadata , delete_idx , axis = 0 )
238
277
else :
239
278
raise ValueError (
240
279
"Attempted to delete route records but either index, routes or "
@@ -260,6 +299,7 @@ def delete_index(self):
260
299
self .index = None
261
300
self .routes = None
262
301
self .utterances = None
302
+ self .metadata = None
263
303
264
304
async def adelete_index (self ):
265
305
"""Deletes the index, effectively clearing it and setting it to None. Note that this just points
@@ -272,6 +312,7 @@ async def adelete_index(self):
272
312
self .index = None
273
313
self .routes = None
274
314
self .utterances = None
315
+ self .metadata = None
275
316
276
317
def _get_indices_for_route (self , route_name : str ):
277
318
"""Gets an array of indices for a specific route.
0 commit comments