Skip to content

Commit 03825fc

Browse files
authored
Merge pull request #592 from aurelio-labs/james/async-limit
fix: add limit to acall and fix async hybrid router
2 parents b46a305 + be2e196 commit 03825fc

File tree

3 files changed

+265
-3
lines changed

3 files changed

+265
-3
lines changed

semantic_router/routers/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,8 @@ def _pass_routes(
639639
route.llm = self.llm
640640
# call dynamic route to generate the function_call content
641641
route_choice = route(query=text)
642+
if route_choice is not None and route_choice.similarity_score is None:
643+
route_choice.similarity_score = total_score
642644
passed_routes.append(route_choice)
643645
elif passed and route is not None and simulate_static:
644646
passed_routes.append(
@@ -757,6 +759,7 @@ async def acall(
757759
self,
758760
text: Optional[str] = None,
759761
vector: Optional[List[float] | np.ndarray] = None,
762+
limit: int | None = 1,
760763
simulate_static: bool = False,
761764
route_filter: Optional[List[str]] = None,
762765
) -> RouteChoice | list[RouteChoice]:
@@ -796,7 +799,7 @@ async def acall(
796799
scored_routes=scored_routes,
797800
simulate_static=simulate_static,
798801
text=text,
799-
limit=1,
802+
limit=limit,
800803
)
801804

802805
def _index_ready(self) -> bool:

semantic_router/routers/hybrid.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def __call__(
319319
if not self.index.is_ready():
320320
raise ValueError("Index is not ready.")
321321
if self.sparse_encoder is None:
322-
raise ValueError
322+
raise ValueError("Sparse encoder is not set.")
323323
potential_sparse_vector: List[SparseEmbedding] | None = None
324324
# if no vector provided, encode text to get vector
325325
if vector is None:
@@ -360,6 +360,70 @@ def __call__(
360360
)
361361
return route_choices
362362

363+
async def acall(
364+
self,
365+
text: Optional[str] = None,
366+
vector: Optional[List[float] | np.ndarray] = None,
367+
limit: int | None = 1,
368+
simulate_static: bool = False,
369+
route_filter: Optional[List[str]] = None,
370+
sparse_vector: dict[int, float] | SparseEmbedding | None = None,
371+
) -> RouteChoice | list[RouteChoice]:
372+
"""Asynchronously call the router to get a route choice.
373+
374+
:param text: The text to route.
375+
:type text: Optional[str]
376+
:param vector: The vector to route.
377+
:type vector: Optional[List[float] | np.ndarray]
378+
:param simulate_static: Whether to simulate a static route (ie avoid dynamic route
379+
LLM calls during fit or evaluate).
380+
:type simulate_static: bool
381+
:param route_filter: The route filter to use.
382+
:type route_filter: Optional[List[str]]
383+
:param sparse_vector: The sparse vector to use.
384+
:type sparse_vector: dict[int, float] | SparseEmbedding | None
385+
:return: The route choice.
386+
:rtype: RouteChoice
387+
"""
388+
if not self.index.is_ready():
389+
# TODO: need async version for qdrant
390+
raise ValueError("Index is not ready.")
391+
if self.sparse_encoder is None:
392+
raise ValueError("Sparse encoder is not set.")
393+
potential_sparse_vector: List[SparseEmbedding] | None = None
394+
# if no vector provided, encode text to get vector
395+
if vector is None:
396+
if text is None:
397+
raise ValueError("Either text or vector must be provided")
398+
vector, potential_sparse_vector = await self._async_encode(
399+
text=[text], input_type="queries"
400+
)
401+
# convert to numpy array if not already
402+
vector = xq_reshape(xq=vector)
403+
if sparse_vector is None:
404+
if text is None:
405+
raise ValueError("Either text or sparse_vector must be provided")
406+
sparse_vector = (
407+
potential_sparse_vector[0] if potential_sparse_vector else None
408+
)
409+
# get scores and routes
410+
scores, routes = await self.index.aquery(
411+
vector=vector[0],
412+
top_k=self.top_k,
413+
route_filter=route_filter,
414+
sparse_vector=sparse_vector,
415+
)
416+
query_results = [
417+
{"route": d, "score": s.item()} for d, s in zip(routes, scores)
418+
]
419+
scored_routes = self._score_routes(query_results=query_results)
420+
return await self._async_pass_routes(
421+
scored_routes=scored_routes,
422+
simulate_static=simulate_static,
423+
text=text,
424+
limit=limit,
425+
)
426+
363427
def _convex_scaling(
364428
self, dense: np.ndarray, sparse: list[SparseEmbedding]
365429
) -> tuple[np.ndarray, list[SparseEmbedding]]:

tests/unit/test_router.py

Lines changed: 196 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from semantic_router.llms import BaseLLM, OpenAILLM
2525
from semantic_router.route import Route
2626
from semantic_router.routers import HybridRouter, RouterConfig, SemanticRouter
27-
from semantic_router.schema import SparseEmbedding
27+
from semantic_router.schema import RouteChoice, SparseEmbedding
2828

2929
PINECONE_SLEEP = 8
3030
RETRY_COUNT = 10
@@ -206,6 +206,16 @@ def routes_4():
206206
]
207207

208208

209+
@pytest.fixture
210+
def routes_5():
211+
return [
212+
Route(name="Route 1", utterances=["Hello", "Hi"], metadata={"type": "default"}),
213+
Route(name="Route 2", utterances=["Goodbye", "Bye", "Au revoir"]),
214+
Route(name="Route 3", utterances=["Hello", "Hi"]),
215+
Route(name="Route 4", utterances=["Goodbye", "Bye", "Au revoir"]),
216+
]
217+
218+
209219
@pytest.fixture
210220
def route_single_utterance():
211221
return [
@@ -643,3 +653,188 @@ async def test_async_encode(
643653
assert sparse_encode_queries_spy.called
644654
else:
645655
assert sparse_call_spy.called
656+
657+
658+
@pytest.mark.parametrize(
659+
"router_cls",
660+
[
661+
HybridRouter,
662+
SemanticRouter,
663+
],
664+
)
665+
class TestRouter:
666+
def test_query_parameter(self, router_cls, routes_5, mocker):
667+
"""Test that we return expected values in RouteChoice objects."""
668+
# Create router with mock encoders
669+
dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder")
670+
if router_cls == HybridRouter:
671+
sparse_encoder = MockSymmetricSparseEncoder(name="Sparse Encoder")
672+
router = router_cls(
673+
encoder=dense_encoder,
674+
sparse_encoder=sparse_encoder,
675+
routes=routes_5,
676+
auto_sync="local",
677+
)
678+
else:
679+
router = router_cls(
680+
encoder=dense_encoder,
681+
routes=routes_5,
682+
auto_sync="local",
683+
)
684+
685+
# Setup a mock for the similarity calculation method
686+
_ = mocker.patch.object(
687+
router,
688+
"_score_routes",
689+
return_value=[
690+
("Route 1", 0.9, [0.1, 0.2, 0.3]),
691+
("Route 2", 0.8, [0.4, 0.5, 0.6]),
692+
("Route 3", 0.7, [0.7, 0.8, 0.9]),
693+
("Route 4", 0.6, [1.0, 1.1, 1.2]),
694+
],
695+
)
696+
697+
# Test without limit (should return only the top match)
698+
result = router("test query")
699+
assert result is not None
700+
assert isinstance(result, RouteChoice)
701+
702+
# Confirm we have Route 1 and sim score
703+
assert result.name == "Route 1"
704+
assert result.similarity_score == 0.9
705+
assert result.function_call is None
706+
707+
@pytest.mark.asyncio
708+
async def test_async_query_parameter(self, router_cls, routes_5, mocker):
709+
"""Test that we return expected values in RouteChoice objects."""
710+
# Create router with mock encoders
711+
dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder")
712+
if router_cls == HybridRouter:
713+
sparse_encoder = MockSymmetricSparseEncoder(name="Sparse Encoder")
714+
router = router_cls(
715+
encoder=dense_encoder,
716+
sparse_encoder=sparse_encoder,
717+
routes=routes_5,
718+
auto_sync="local",
719+
)
720+
else:
721+
router = router_cls(
722+
encoder=dense_encoder,
723+
routes=routes_5,
724+
auto_sync="local",
725+
)
726+
727+
# Setup a mock for the similarity calculation method
728+
_ = mocker.patch.object(
729+
router,
730+
"_score_routes",
731+
return_value=[
732+
("Route 1", 0.9, [0.1, 0.2, 0.3]),
733+
("Route 2", 0.8, [0.4, 0.5, 0.6]),
734+
("Route 3", 0.7, [0.7, 0.8, 0.9]),
735+
("Route 4", 0.6, [1.0, 1.1, 1.2]),
736+
],
737+
)
738+
739+
# Test without limit (should return only the top match)
740+
result = await router.acall("test query")
741+
assert result is not None
742+
assert isinstance(result, RouteChoice)
743+
744+
# Confirm we have Route 1 and sim score
745+
assert result.name == "Route 1"
746+
assert result.similarity_score == 0.9
747+
assert result.function_call is None
748+
749+
def test_limit_parameter(self, router_cls, routes_5, mocker):
750+
"""Test that the limit parameter works correctly for sync router calls."""
751+
# Create router with mock encoders
752+
dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder")
753+
if router_cls == HybridRouter:
754+
sparse_encoder = MockSymmetricSparseEncoder(name="Sparse Encoder")
755+
router = router_cls(
756+
encoder=dense_encoder,
757+
sparse_encoder=sparse_encoder,
758+
routes=routes_5,
759+
auto_sync="local",
760+
)
761+
else:
762+
router = router_cls(
763+
encoder=dense_encoder,
764+
routes=routes_5,
765+
auto_sync="local",
766+
)
767+
768+
# Setup a mock for the similarity calculation method
769+
_ = mocker.patch.object(
770+
router,
771+
"_score_routes",
772+
return_value=[
773+
("Route 1", 0.9, [0.1, 0.2, 0.3]),
774+
("Route 2", 0.8, [0.4, 0.5, 0.6]),
775+
("Route 3", 0.7, [0.7, 0.8, 0.9]),
776+
("Route 4", 0.6, [1.0, 1.1, 1.2]),
777+
],
778+
)
779+
780+
# Test without limit (should return only the top match)
781+
result = router("test query")
782+
assert result is not None
783+
assert isinstance(result, RouteChoice)
784+
785+
# Test with limit=2 (should return top 2 matches)
786+
result = router("test query", limit=2)
787+
assert result is not None
788+
assert len(result) == 2
789+
790+
# Test with limit=None (should return all matches)
791+
result = router("test query", limit=None)
792+
assert result is not None
793+
assert len(result) == 4 # Should return all matches
794+
795+
@pytest.mark.asyncio
796+
async def test_async_limit_parameter(self, router_cls, routes_5, mocker):
797+
"""Test that the limit parameter works correctly for async router calls."""
798+
# Create router with mock encoders
799+
dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder")
800+
if router_cls == HybridRouter:
801+
sparse_encoder = MockSymmetricSparseEncoder(name="Sparse Encoder")
802+
router = router_cls(
803+
encoder=dense_encoder,
804+
sparse_encoder=sparse_encoder,
805+
routes=routes_5,
806+
auto_sync="local",
807+
)
808+
else:
809+
router = router_cls(
810+
encoder=dense_encoder,
811+
routes=routes_5,
812+
auto_sync="local",
813+
)
814+
815+
# Setup a mock for the async similarity calculation method
816+
_ = mocker.patch.object(
817+
router,
818+
"_score_routes",
819+
return_value=[
820+
("Route 1", 0.9, [0.1, 0.2, 0.3]),
821+
("Route 2", 0.8, [0.4, 0.5, 0.6]),
822+
("Route 3", 0.7, [0.7, 0.8, 0.9]),
823+
("Route 4", 0.6, [1.0, 1.1, 1.2]),
824+
],
825+
)
826+
827+
# Test without limit (should return only the top match)
828+
result = await router.acall("test query")
829+
assert result is not None
830+
assert isinstance(result, RouteChoice)
831+
832+
# Test with limit=2 (should return top 2 matches)
833+
result = await router.acall("test query", limit=2)
834+
assert result is not None
835+
assert len(result) == 2
836+
837+
# Test with limit=None (should return all matches)
838+
result = await router.acall("test query", limit=None)
839+
assert result is not None
840+
assert len(result) == 4 # Should return all matches

0 commit comments

Comments
 (0)