Skip to content

Commit be2e196

Browse files
committed
fix: add sim score to tests and sync call
1 parent b587f62 commit be2e196

File tree

3 files changed

+91
-10
lines changed

3 files changed

+91
-10
lines changed

semantic_router/routers/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,8 @@ def _pass_routes(
639639
route.llm = self.llm
640640
# call dynamic route to generate the function_call content
641641
route_choice = route(query=text)
642+
if route_choice is not None and route_choice.similarity_score is None:
643+
route_choice.similarity_score = total_score
642644
passed_routes.append(route_choice)
643645
elif passed and route is not None and simulate_static:
644646
passed_routes.append(

semantic_router/routers/hybrid.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def add(self, routes: List[Route] | Route):
9191
"""Add a route to the local HybridRouter and index.
9292
9393
:param route: The route to add.
94-
:type route: Route
94+
:type route: Route
9595
"""
9696

9797
if self.sparse_encoder is None:
@@ -359,7 +359,7 @@ def __call__(
359359
limit=limit,
360360
)
361361
return route_choices
362-
362+
363363
async def acall(
364364
self,
365365
text: Optional[str] = None,
@@ -396,8 +396,7 @@ async def acall(
396396
if text is None:
397397
raise ValueError("Either text or vector must be provided")
398398
vector, potential_sparse_vector = await self._async_encode(
399-
text=[text],
400-
input_type="queries"
399+
text=[text], input_type="queries"
401400
)
402401
# convert to numpy array if not already
403402
vector = xq_reshape(xq=vector)

tests/unit/test_router.py

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def routes_4():
205205
Route(name="Route 2", utterances=["Asparagus"]),
206206
]
207207

208+
208209
@pytest.fixture
209210
def routes_5():
210211
return [
@@ -662,9 +663,90 @@ async def test_async_encode(
662663
],
663664
)
664665
class TestRouter:
665-
def test_limit_parameter(
666-
self, router_cls, routes_5, mocker
667-
):
666+
def test_query_parameter(self, router_cls, routes_5, mocker):
667+
"""Test that we return expected values in RouteChoice objects."""
668+
# Create router with mock encoders
669+
dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder")
670+
if router_cls == HybridRouter:
671+
sparse_encoder = MockSymmetricSparseEncoder(name="Sparse Encoder")
672+
router = router_cls(
673+
encoder=dense_encoder,
674+
sparse_encoder=sparse_encoder,
675+
routes=routes_5,
676+
auto_sync="local",
677+
)
678+
else:
679+
router = router_cls(
680+
encoder=dense_encoder,
681+
routes=routes_5,
682+
auto_sync="local",
683+
)
684+
685+
# Setup a mock for the similarity calculation method
686+
_ = mocker.patch.object(
687+
router,
688+
"_score_routes",
689+
return_value=[
690+
("Route 1", 0.9, [0.1, 0.2, 0.3]),
691+
("Route 2", 0.8, [0.4, 0.5, 0.6]),
692+
("Route 3", 0.7, [0.7, 0.8, 0.9]),
693+
("Route 4", 0.6, [1.0, 1.1, 1.2]),
694+
],
695+
)
696+
697+
# Test without limit (should return only the top match)
698+
result = router("test query")
699+
assert result is not None
700+
assert isinstance(result, RouteChoice)
701+
702+
# Confirm we have Route 1 and sim score
703+
assert result.name == "Route 1"
704+
assert result.similarity_score == 0.9
705+
assert result.function_call is None
706+
707+
@pytest.mark.asyncio
708+
async def test_async_query_parameter(self, router_cls, routes_5, mocker):
709+
"""Test that we return expected values in RouteChoice objects."""
710+
# Create router with mock encoders
711+
dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder")
712+
if router_cls == HybridRouter:
713+
sparse_encoder = MockSymmetricSparseEncoder(name="Sparse Encoder")
714+
router = router_cls(
715+
encoder=dense_encoder,
716+
sparse_encoder=sparse_encoder,
717+
routes=routes_5,
718+
auto_sync="local",
719+
)
720+
else:
721+
router = router_cls(
722+
encoder=dense_encoder,
723+
routes=routes_5,
724+
auto_sync="local",
725+
)
726+
727+
# Setup a mock for the similarity calculation method
728+
_ = mocker.patch.object(
729+
router,
730+
"_score_routes",
731+
return_value=[
732+
("Route 1", 0.9, [0.1, 0.2, 0.3]),
733+
("Route 2", 0.8, [0.4, 0.5, 0.6]),
734+
("Route 3", 0.7, [0.7, 0.8, 0.9]),
735+
("Route 4", 0.6, [1.0, 1.1, 1.2]),
736+
],
737+
)
738+
739+
# Test without limit (should return only the top match)
740+
result = await router.acall("test query")
741+
assert result is not None
742+
assert isinstance(result, RouteChoice)
743+
744+
# Confirm we have Route 1 and sim score
745+
assert result.name == "Route 1"
746+
assert result.similarity_score == 0.9
747+
assert result.function_call is None
748+
749+
def test_limit_parameter(self, router_cls, routes_5, mocker):
668750
"""Test that the limit parameter works correctly for sync router calls."""
669751
# Create router with mock encoders
670752
dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder")
@@ -711,9 +793,7 @@ def test_limit_parameter(
711793
assert len(result) == 4 # Should return all matches
712794

713795
@pytest.mark.asyncio
714-
async def test_async_limit_parameter(
715-
self, router_cls, routes_5, mocker
716-
):
796+
async def test_async_limit_parameter(self, router_cls, routes_5, mocker):
717797
"""Test that the limit parameter works correctly for async router calls."""
718798
# Create router with mock encoders
719799
dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder")

0 commit comments

Comments
 (0)