@@ -205,6 +205,7 @@ def routes_4():
205
205
Route (name = "Route 2" , utterances = ["Asparagus" ]),
206
206
]
207
207
208
+
208
209
@pytest .fixture
209
210
def routes_5 ():
210
211
return [
@@ -662,9 +663,90 @@ async def test_async_encode(
662
663
],
663
664
)
664
665
class TestRouter :
665
- def test_limit_parameter (
666
- self , router_cls , routes_5 , mocker
667
- ):
666
+ def test_query_parameter (self , router_cls , routes_5 , mocker ):
667
+ """Test that we return expected values in RouteChoice objects."""
668
+ # Create router with mock encoders
669
+ dense_encoder = MockSymmetricDenseEncoder (name = "Dense Encoder" )
670
+ if router_cls == HybridRouter :
671
+ sparse_encoder = MockSymmetricSparseEncoder (name = "Sparse Encoder" )
672
+ router = router_cls (
673
+ encoder = dense_encoder ,
674
+ sparse_encoder = sparse_encoder ,
675
+ routes = routes_5 ,
676
+ auto_sync = "local" ,
677
+ )
678
+ else :
679
+ router = router_cls (
680
+ encoder = dense_encoder ,
681
+ routes = routes_5 ,
682
+ auto_sync = "local" ,
683
+ )
684
+
685
+ # Setup a mock for the similarity calculation method
686
+ _ = mocker .patch .object (
687
+ router ,
688
+ "_score_routes" ,
689
+ return_value = [
690
+ ("Route 1" , 0.9 , [0.1 , 0.2 , 0.3 ]),
691
+ ("Route 2" , 0.8 , [0.4 , 0.5 , 0.6 ]),
692
+ ("Route 3" , 0.7 , [0.7 , 0.8 , 0.9 ]),
693
+ ("Route 4" , 0.6 , [1.0 , 1.1 , 1.2 ]),
694
+ ],
695
+ )
696
+
697
+ # Test without limit (should return only the top match)
698
+ result = router ("test query" )
699
+ assert result is not None
700
+ assert isinstance (result , RouteChoice )
701
+
702
+ # Confirm we have Route 1 and sim score
703
+ assert result .name == "Route 1"
704
+ assert result .similarity_score == 0.9
705
+ assert result .function_call is None
706
+
707
+ @pytest .mark .asyncio
708
+ async def test_async_query_parameter (self , router_cls , routes_5 , mocker ):
709
+ """Test that we return expected values in RouteChoice objects."""
710
+ # Create router with mock encoders
711
+ dense_encoder = MockSymmetricDenseEncoder (name = "Dense Encoder" )
712
+ if router_cls == HybridRouter :
713
+ sparse_encoder = MockSymmetricSparseEncoder (name = "Sparse Encoder" )
714
+ router = router_cls (
715
+ encoder = dense_encoder ,
716
+ sparse_encoder = sparse_encoder ,
717
+ routes = routes_5 ,
718
+ auto_sync = "local" ,
719
+ )
720
+ else :
721
+ router = router_cls (
722
+ encoder = dense_encoder ,
723
+ routes = routes_5 ,
724
+ auto_sync = "local" ,
725
+ )
726
+
727
+ # Setup a mock for the similarity calculation method
728
+ _ = mocker .patch .object (
729
+ router ,
730
+ "_score_routes" ,
731
+ return_value = [
732
+ ("Route 1" , 0.9 , [0.1 , 0.2 , 0.3 ]),
733
+ ("Route 2" , 0.8 , [0.4 , 0.5 , 0.6 ]),
734
+ ("Route 3" , 0.7 , [0.7 , 0.8 , 0.9 ]),
735
+ ("Route 4" , 0.6 , [1.0 , 1.1 , 1.2 ]),
736
+ ],
737
+ )
738
+
739
+ # Test without limit (should return only the top match)
740
+ result = await router .acall ("test query" )
741
+ assert result is not None
742
+ assert isinstance (result , RouteChoice )
743
+
744
+ # Confirm we have Route 1 and sim score
745
+ assert result .name == "Route 1"
746
+ assert result .similarity_score == 0.9
747
+ assert result .function_call is None
748
+
749
+ def test_limit_parameter (self , router_cls , routes_5 , mocker ):
668
750
"""Test that the limit parameter works correctly for sync router calls."""
669
751
# Create router with mock encoders
670
752
dense_encoder = MockSymmetricDenseEncoder (name = "Dense Encoder" )
@@ -711,9 +793,7 @@ def test_limit_parameter(
711
793
assert len (result ) == 4 # Should return all matches
712
794
713
795
@pytest .mark .asyncio
714
- async def test_async_limit_parameter (
715
- self , router_cls , routes_5 , mocker
716
- ):
796
+ async def test_async_limit_parameter (self , router_cls , routes_5 , mocker ):
717
797
"""Test that the limit parameter works correctly for async router calls."""
718
798
# Create router with mock encoders
719
799
dense_encoder = MockSymmetricDenseEncoder (name = "Dense Encoder" )
0 commit comments