Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 24 additions & 112 deletions docs/custom_tool_guides/tool_guide.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Custom tools and retrieval sources
# Custom Tools
Follow these instructions to create your own custom tools.

## Step 1: Choose a Tool to Implement
Expand All @@ -25,115 +25,48 @@ There are three types of tools:

## Step 3: Implement the Tool

Add your tool implementation [here](https://github.com/cohere-ai/toolkit/tree/main/src/community/tools) (please note that this link might change). The specific subfolder used will depend on the type of tool you're implementing.
Add your tool implementation [here](https://github.com/cohere-ai/toolkit/tree/main/src/community/tools) (please note that this link is subject to change).

If you need to install a new module to run your tool, execute the following command and run `make dev` again.
If you need to install a new library to run your tool, execute the following command and run `make dev` again.

```bash
poetry add <MODULE> --group community
```
### Implementing a Tool

If you're working on a File or Data Loader, follow the steps outlined in [Implementing a Retriever](#implementing-a-retriever).
Add the implementation inside a tool class that inherits from `BaseTool`. This class will need to implement
the `call()` method, which should return a list of dictionary results.

If you're implementing a Function Tool, refer to the steps in [Implementing a Function Tool](#implementing-a-function-tool).

### Implementing a Retriever

Add the implementation inside a tool class that inherits `BaseRetrieval` and needs to implement the function `def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:`

You can define custom configurations for your tool within the `__init__` function. Set the exact values for these variables during [Step 4](#step-4-making-your-tool-available).

You can also develop a tool that requires a token or authentication. To do this, simply set your variable in the .env file.

For example, for Wikipedia we have a custom configuration:

```python
class LangChainWikiRetriever(BaseRetrieval):
"""
This class retrieves documents from Wikipedia using the langchain package.
This requires wikipedia package to be installed.
"""

def __init__(self, chunk_size: int = 300, chunk_overlap: int = 0):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap

def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
wiki_retriever = WikipediaRetriever()
docs = wiki_retriever.get_relevant_documents(query)
text_splitter = CharacterTextSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
)
documents = text_splitter.split_documents(docs)
return [
{
"text": doc.page_content,
"title": doc.metadata.get("title", None),
"url": doc.metadata.get("source", None),
}
for doc in documents
]
```

And for internet search, we need an API key
For example, let's look at the community-implemented `ArxivRetriever`:

```python
class TavilyInternetSearch(BaseRetrieval):
def __init__(self):
if "TAVILY_API_KEY" not in os.environ:
raise ValueError("Please set the TAVILY_API_KEY environment variable.")

self.api_key = os.environ["TAVILY_API_KEY"]
self.client = TavilyClient(api_key=self.api_key)

def retrieve_documents(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
content = self.client.search(query=query, search_depth="advanced")
from typing import Any, Dict, List

if "results" not in content:
return []
from langchain_community.utilities import ArxivAPIWrapper

return [
{
"url": result["url"],
"text": result["content"],
}
for result in content["results"]
```

Note that all Retrievers should return a list of Dicts, and each Dict should contain at least a `text` key.

### Implementing a Function Tool
from community.tools import BaseTool

Add the implementation inside a tool class that inherits `BaseFunctionTool` and needs to implement the function `def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]:`

For example, for calculator

```python
from typing import Any
from py_expression_eval import Parser
from typing import List, Dict

from backend.tools.function_tools.base import BaseFunctionTool
class ArxivRetriever(BaseTool):
def __init__(self):
self.client = ArxivAPIWrapper()

class CalculatorFunctionTool(BaseFunctionTool):
"""
Function Tool that evaluates mathematical expressions.
"""
@classmethod
# If your tool requires any environment variables such as API keys,
# you will need to assert that they're not None here
def is_available(cls) -> bool:
return True

# Your tool needs to implement this call() method
def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]:
math_parser = Parser()
to_evaluate = parameters.get("code", "").replace("pi", "PI").replace("e", "E")
result = []
try:
result = {"result": math_parser.parse(to_evaluate).evaluate({})}
except Exception:
result = {"result": "Parsing error - syntax not allowed."}
return result
result = self.client.run(parameters)

return [{"text": result}] # <- Return list of results, in this case there is only one
```

## Step 4: Making Your Tool Available

To make your tool available, add its definition to the tools config [here](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py).
To make your tool available, add its definition to the community tools [config.py](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py).

Start by adding the tool name to the `ToolName` enum found at the top of the file.

Expand All @@ -149,27 +82,6 @@ Next, include the tool configurations in the `AVAILABLE_TOOLS` list. The definit
- Description: A brief description of the tool.
- Env_vars: A list of secrets required by the tool.

Function tool with custom parameter definitions:

```python
ToolName.Python_Interpreter: ManagedTool(
name=ToolName.Python_Interpreter,
implementation=PythonInterpreterFunctionTool,
parameter_definitions={
"code": {
"description": "Python code to execute using an interpreter",
"type": "str",
"required": True,
}
},
is_visible=True,
is_available=PythonInterpreterFunctionTool.is_available(),
error_message="PythonInterpreterFunctionTool not available, please make sure to set the PYTHON_INTERPRETER_URL environment variable.",
category=Category.Function,
description="Runs python code in a sandbox.",
)
```

## Step 5: Test Your Tool!

Now, when you run the toolkit, all the visible tools, including the one you just added, should be available!
Expand Down Expand Up @@ -207,4 +119,4 @@ curl --location 'http://localhost:8000/chat-stream' \

## Step 6 (extra): Add Unit tests

If you would like to go above and beyond, it would be helpful to add some unit tests to ensure that your tool is working as expected. Create a file [here](https://github.com/cohere-ai/cohere-toolkit/tree/main/src/community/tests/tools) and add a few cases.
If you would like to go above and beyond, it would be helpful to add some unit tests to ensure that your tool is working as expected. Create a file [here](https://github.com/cohere-ai/cohere-toolkit/tree/main/src/community/tests/tools) and add a few test cases.
File renamed without changes.
7 changes: 3 additions & 4 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from fastapi import HTTPException

from backend.chat.base import BaseChat
from backend.chat.collate import combine_documents
from backend.config.tools import AVAILABLE_TOOLS, ToolName
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_deployment
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.tool import Category, Tool
from backend.services.logger import get_logger
from backend.tools.retrieval.collate import combine_documents


class CustomChat(BaseChat):
Expand Down Expand Up @@ -84,11 +84,10 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:

all_documents = {}
# TODO: call in parallel and error handling
# TODO: merge with regular function tools after multihop implemented
for retriever in retrievers:
for query in queries:
all_documents.setdefault(query, []).extend(
retriever.retrieve_documents(query)
)
all_documents.setdefault(query, []).extend(retriever.call(query))

# Collate Documents
documents = combine_documents(all_documents, deployment_model)
Expand Down
16 changes: 7 additions & 9 deletions src/backend/config/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from enum import StrEnum

from backend.schemas.tool import Category, ManagedTool
from backend.tools.function_tools import (
CalculatorFunctionTool,
PythonInterpreterFunctionTool,
)
from backend.tools.retrieval import (
from backend.tools import (
Calculator,
LangChainVectorDBRetriever,
LangChainWikiRetriever,
PythonInterpreter,
TavilyInternetSearch,
)

Expand Down Expand Up @@ -56,7 +54,7 @@ class ToolName(StrEnum):
),
ToolName.Python_Interpreter: ManagedTool(
name=ToolName.Python_Interpreter,
implementation=PythonInterpreterFunctionTool,
implementation=PythonInterpreter,
parameter_definitions={
"code": {
"description": "Python code to execute using an interpreter",
Expand All @@ -65,14 +63,14 @@ class ToolName(StrEnum):
}
},
is_visible=True,
is_available=PythonInterpreterFunctionTool.is_available(),
is_available=PythonInterpreter.is_available(),
error_message="PythonInterpreterFunctionTool not available, please make sure to set the PYTHON_INTERPRETER_URL environment variable.",
category=Category.Function,
description="Runs python code in a sandbox.",
),
ToolName.Calculator: ManagedTool(
name=ToolName.Calculator,
implementation=CalculatorFunctionTool,
implementation=Calculator,
parameter_definitions={
"code": {
"description": "Arithmetic expression to evaluate",
Expand All @@ -81,7 +79,7 @@ class ToolName(StrEnum):
}
},
is_visible=True,
is_available=CalculatorFunctionTool.is_available(),
is_available=Calculator.is_available(),
error_message="CalculatorFunctionTool not available.",
category=Category.Function,
description="Evaluate arithmetic expressions.",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from backend.tools.function_tools import CalculatorFunctionTool
from backend.tools import Calculator


def test_calculator() -> None:
calculator = CalculatorFunctionTool()
calculator = Calculator()
result = calculator.call({"code": "2+2"})
assert result == {"result": 4}


def test_calculator_invalid_syntax() -> None:
calculator = CalculatorFunctionTool()
calculator = Calculator()
result = calculator.call({"code": "2+"})
assert result == {"result": "Parsing error - syntax not allowed."}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import pytest

from backend.chat import collate
from backend.model_deployments import CohereDeployment
from backend.tools.retrieval import collate

is_cohere_env_set = (
os.environ.get("COHERE_API_KEY") is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
import pytest
from langchain_core.documents.base import Document

from backend.tools.retrieval.lang_chain import (
LangChainVectorDBRetriever,
LangChainWikiRetriever,
)
from backend.tools import LangChainVectorDBRetriever, LangChainWikiRetriever

is_cohere_env_set = (
os.environ.get("COHERE_API_KEY") is not None
Expand Down Expand Up @@ -53,10 +50,10 @@ def test_wiki_retriever() -> None:
wiki_retriever_mock.get_relevant_documents.return_value = mock_docs

with patch(
"backend.tools.retrieval.lang_chain.WikipediaRetriever",
"backend.tools.lang_chain.WikipediaRetriever",
return_value=wiki_retriever_mock,
):
result = retriever.retrieve_documents(query)
result = retriever.call(query)

assert result == expected_docs

Expand All @@ -71,10 +68,10 @@ def test_wiki_retriever_no_docs() -> None:
wiki_retriever_mock.get_relevant_documents.return_value = mock_docs

with patch(
"backend.tools.retrieval.lang_chain.WikipediaRetriever",
"backend.tools.lang_chain.WikipediaRetriever",
return_value=wiki_retriever_mock,
):
result = retriever.retrieve_documents(query)
result = retriever.call(query)

assert result == []

Expand Down Expand Up @@ -134,7 +131,7 @@ def test_vector_db_retriever() -> None:
mock_db = MagicMock()
mock_from_documents.return_value = mock_db
mock_db.as_retriever().get_relevant_documents.return_value = mock_docs
result = retriever.retrieve_documents(query)
result = retriever.call(query)

assert result == expected_docs

Expand All @@ -155,6 +152,6 @@ def test_vector_db_retriever_no_docs() -> None:
mock_db = MagicMock()
mock_from_documents.return_value = mock_db
mock_db.as_retriever().get_relevant_documents.return_value = mock_docs
result = retriever.retrieve_documents(query)
result = retriever.call(query)

assert result == []
12 changes: 12 additions & 0 deletions src/backend/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from backend.tools.calculator import Calculator
from backend.tools.lang_chain import LangChainVectorDBRetriever, LangChainWikiRetriever
from backend.tools.python_interpreter import PythonInterpreter
from backend.tools.tavily import TavilyInternetSearch

__all__ = [
"Calculator",
"PythonInterpreter",
"LangChainVectorDBRetriever",
"LangChainWikiRetriever",
"TavilyInternetSearch",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from typing import Any, Dict, List


class BaseFunctionTool:
"""Base for all retrieval options."""
class BaseTool:
"""
Abstract base class for all Tools.
"""

@classmethod
@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from py_expression_eval import Parser

from backend.tools.function_tools.base import BaseFunctionTool
from backend.tools.base import BaseTool


class CalculatorFunctionTool(BaseFunctionTool):
class CalculatorFunctionTool(BaseTool):
"""
Function Tool that evaluates mathematical expressions.
"""
Expand All @@ -17,6 +17,7 @@ def is_available(cls) -> bool:
def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]:
math_parser = Parser()
to_evaluate = parameters.get("code", "").replace("pi", "PI").replace("e", "E")

result = []
try:
result = {"result": math_parser.parse(to_evaluate).evaluate({})}
Expand Down
9 changes: 0 additions & 9 deletions src/backend/tools/function_tools/__init__.py

This file was deleted.

Loading