|
24 | 24 | from semantic_router.llms import BaseLLM, OpenAILLM
|
25 | 25 | from semantic_router.route import Route
|
26 | 26 | from semantic_router.routers import HybridRouter, RouterConfig, SemanticRouter
|
27 |
| -from semantic_router.schema import SparseEmbedding |
| 27 | +from semantic_router.schema import RouteChoice, SparseEmbedding |
28 | 28 |
|
29 | 29 | PINECONE_SLEEP = 8
|
30 | 30 | RETRY_COUNT = 10
|
@@ -206,6 +206,16 @@ def routes_4():
|
206 | 206 | ]
|
207 | 207 |
|
208 | 208 |
|
| 209 | +@pytest.fixture |
| 210 | +def routes_5(): |
| 211 | + return [ |
| 212 | + Route(name="Route 1", utterances=["Hello", "Hi"], metadata={"type": "default"}), |
| 213 | + Route(name="Route 2", utterances=["Goodbye", "Bye", "Au revoir"]), |
| 214 | + Route(name="Route 3", utterances=["Hello", "Hi"]), |
| 215 | + Route(name="Route 4", utterances=["Goodbye", "Bye", "Au revoir"]), |
| 216 | + ] |
| 217 | + |
| 218 | + |
209 | 219 | @pytest.fixture
|
210 | 220 | def route_single_utterance():
|
211 | 221 | return [
|
@@ -643,3 +653,188 @@ async def test_async_encode(
|
643 | 653 | assert sparse_encode_queries_spy.called
|
644 | 654 | else:
|
645 | 655 | assert sparse_call_spy.called
|
| 656 | + |
| 657 | + |
| 658 | +@pytest.mark.parametrize( |
| 659 | + "router_cls", |
| 660 | + [ |
| 661 | + HybridRouter, |
| 662 | + SemanticRouter, |
| 663 | + ], |
| 664 | +) |
| 665 | +class TestRouter: |
| 666 | + def test_query_parameter(self, router_cls, routes_5, mocker): |
| 667 | + """Test that we return expected values in RouteChoice objects.""" |
| 668 | + # Create router with mock encoders |
| 669 | + dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder") |
| 670 | + if router_cls == HybridRouter: |
| 671 | + sparse_encoder = MockSymmetricSparseEncoder(name="Sparse Encoder") |
| 672 | + router = router_cls( |
| 673 | + encoder=dense_encoder, |
| 674 | + sparse_encoder=sparse_encoder, |
| 675 | + routes=routes_5, |
| 676 | + auto_sync="local", |
| 677 | + ) |
| 678 | + else: |
| 679 | + router = router_cls( |
| 680 | + encoder=dense_encoder, |
| 681 | + routes=routes_5, |
| 682 | + auto_sync="local", |
| 683 | + ) |
| 684 | + |
| 685 | + # Setup a mock for the similarity calculation method |
| 686 | + _ = mocker.patch.object( |
| 687 | + router, |
| 688 | + "_score_routes", |
| 689 | + return_value=[ |
| 690 | + ("Route 1", 0.9, [0.1, 0.2, 0.3]), |
| 691 | + ("Route 2", 0.8, [0.4, 0.5, 0.6]), |
| 692 | + ("Route 3", 0.7, [0.7, 0.8, 0.9]), |
| 693 | + ("Route 4", 0.6, [1.0, 1.1, 1.2]), |
| 694 | + ], |
| 695 | + ) |
| 696 | + |
| 697 | + # Test without limit (should return only the top match) |
| 698 | + result = router("test query") |
| 699 | + assert result is not None |
| 700 | + assert isinstance(result, RouteChoice) |
| 701 | + |
| 702 | + # Confirm we have Route 1 and sim score |
| 703 | + assert result.name == "Route 1" |
| 704 | + assert result.similarity_score == 0.9 |
| 705 | + assert result.function_call is None |
| 706 | + |
| 707 | + @pytest.mark.asyncio |
| 708 | + async def test_async_query_parameter(self, router_cls, routes_5, mocker): |
| 709 | + """Test that we return expected values in RouteChoice objects.""" |
| 710 | + # Create router with mock encoders |
| 711 | + dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder") |
| 712 | + if router_cls == HybridRouter: |
| 713 | + sparse_encoder = MockSymmetricSparseEncoder(name="Sparse Encoder") |
| 714 | + router = router_cls( |
| 715 | + encoder=dense_encoder, |
| 716 | + sparse_encoder=sparse_encoder, |
| 717 | + routes=routes_5, |
| 718 | + auto_sync="local", |
| 719 | + ) |
| 720 | + else: |
| 721 | + router = router_cls( |
| 722 | + encoder=dense_encoder, |
| 723 | + routes=routes_5, |
| 724 | + auto_sync="local", |
| 725 | + ) |
| 726 | + |
| 727 | + # Setup a mock for the similarity calculation method |
| 728 | + _ = mocker.patch.object( |
| 729 | + router, |
| 730 | + "_score_routes", |
| 731 | + return_value=[ |
| 732 | + ("Route 1", 0.9, [0.1, 0.2, 0.3]), |
| 733 | + ("Route 2", 0.8, [0.4, 0.5, 0.6]), |
| 734 | + ("Route 3", 0.7, [0.7, 0.8, 0.9]), |
| 735 | + ("Route 4", 0.6, [1.0, 1.1, 1.2]), |
| 736 | + ], |
| 737 | + ) |
| 738 | + |
| 739 | + # Test without limit (should return only the top match) |
| 740 | + result = await router.acall("test query") |
| 741 | + assert result is not None |
| 742 | + assert isinstance(result, RouteChoice) |
| 743 | + |
| 744 | + # Confirm we have Route 1 and sim score |
| 745 | + assert result.name == "Route 1" |
| 746 | + assert result.similarity_score == 0.9 |
| 747 | + assert result.function_call is None |
| 748 | + |
| 749 | + def test_limit_parameter(self, router_cls, routes_5, mocker): |
| 750 | + """Test that the limit parameter works correctly for sync router calls.""" |
| 751 | + # Create router with mock encoders |
| 752 | + dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder") |
| 753 | + if router_cls == HybridRouter: |
| 754 | + sparse_encoder = MockSymmetricSparseEncoder(name="Sparse Encoder") |
| 755 | + router = router_cls( |
| 756 | + encoder=dense_encoder, |
| 757 | + sparse_encoder=sparse_encoder, |
| 758 | + routes=routes_5, |
| 759 | + auto_sync="local", |
| 760 | + ) |
| 761 | + else: |
| 762 | + router = router_cls( |
| 763 | + encoder=dense_encoder, |
| 764 | + routes=routes_5, |
| 765 | + auto_sync="local", |
| 766 | + ) |
| 767 | + |
| 768 | + # Setup a mock for the similarity calculation method |
| 769 | + _ = mocker.patch.object( |
| 770 | + router, |
| 771 | + "_score_routes", |
| 772 | + return_value=[ |
| 773 | + ("Route 1", 0.9, [0.1, 0.2, 0.3]), |
| 774 | + ("Route 2", 0.8, [0.4, 0.5, 0.6]), |
| 775 | + ("Route 3", 0.7, [0.7, 0.8, 0.9]), |
| 776 | + ("Route 4", 0.6, [1.0, 1.1, 1.2]), |
| 777 | + ], |
| 778 | + ) |
| 779 | + |
| 780 | + # Test without limit (should return only the top match) |
| 781 | + result = router("test query") |
| 782 | + assert result is not None |
| 783 | + assert isinstance(result, RouteChoice) |
| 784 | + |
| 785 | + # Test with limit=2 (should return top 2 matches) |
| 786 | + result = router("test query", limit=2) |
| 787 | + assert result is not None |
| 788 | + assert len(result) == 2 |
| 789 | + |
| 790 | + # Test with limit=None (should return all matches) |
| 791 | + result = router("test query", limit=None) |
| 792 | + assert result is not None |
| 793 | + assert len(result) == 4 # Should return all matches |
| 794 | + |
| 795 | + @pytest.mark.asyncio |
| 796 | + async def test_async_limit_parameter(self, router_cls, routes_5, mocker): |
| 797 | + """Test that the limit parameter works correctly for async router calls.""" |
| 798 | + # Create router with mock encoders |
| 799 | + dense_encoder = MockSymmetricDenseEncoder(name="Dense Encoder") |
| 800 | + if router_cls == HybridRouter: |
| 801 | + sparse_encoder = MockSymmetricSparseEncoder(name="Sparse Encoder") |
| 802 | + router = router_cls( |
| 803 | + encoder=dense_encoder, |
| 804 | + sparse_encoder=sparse_encoder, |
| 805 | + routes=routes_5, |
| 806 | + auto_sync="local", |
| 807 | + ) |
| 808 | + else: |
| 809 | + router = router_cls( |
| 810 | + encoder=dense_encoder, |
| 811 | + routes=routes_5, |
| 812 | + auto_sync="local", |
| 813 | + ) |
| 814 | + |
| 815 | + # Setup a mock for the async similarity calculation method |
| 816 | + _ = mocker.patch.object( |
| 817 | + router, |
| 818 | + "_score_routes", |
| 819 | + return_value=[ |
| 820 | + ("Route 1", 0.9, [0.1, 0.2, 0.3]), |
| 821 | + ("Route 2", 0.8, [0.4, 0.5, 0.6]), |
| 822 | + ("Route 3", 0.7, [0.7, 0.8, 0.9]), |
| 823 | + ("Route 4", 0.6, [1.0, 1.1, 1.2]), |
| 824 | + ], |
| 825 | + ) |
| 826 | + |
| 827 | + # Test without limit (should return only the top match) |
| 828 | + result = await router.acall("test query") |
| 829 | + assert result is not None |
| 830 | + assert isinstance(result, RouteChoice) |
| 831 | + |
| 832 | + # Test with limit=2 (should return top 2 matches) |
| 833 | + result = await router.acall("test query", limit=2) |
| 834 | + assert result is not None |
| 835 | + assert len(result) == 2 |
| 836 | + |
| 837 | + # Test with limit=None (should return all matches) |
| 838 | + result = await router.acall("test query", limit=None) |
| 839 | + assert result is not None |
| 840 | + assert len(result) == 4 # Should return all matches |
0 commit comments