diff --git a/README.md b/README.md index 239bc6f2c7..35a2edf5a6 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Toolkit is a deployable all-in-one RAG application that enables users to quickly - [How to setup Google Drive](/docs/custom_tool_guides/google_drive.md) - [How to setup Gmail](/docs/custom_tool_guides/gmail.md) - [How to setup Slack Tool](/docs/custom_tool_guides/slack.md) + - [How to setup Github Tool](/docs/custom_tool_guides/github.md) - [How to setup Google Text-to-Speech](/docs/text_to_speech.md) - [How to add authentication](/docs/auth_guide.md) - [How to deploy toolkit services](/docs/service_deployments.md) diff --git a/docs/custom_tool_guides/github.md b/docs/custom_tool_guides/github.md new file mode 100644 index 0000000000..c854522626 --- /dev/null +++ b/docs/custom_tool_guides/github.md @@ -0,0 +1,50 @@ +# Github Tool Setup + +To set up the Github tool you will need a Github application. Follow the steps below to set it up: + +## 1. Create a Github App + +Head to the [Github Settings](https://github.com/settings/apps) and create a new app. +Specify App [permissions](https://docs.github.com/rest/overview/permissions-required-for-github-apps), Callback URL (for local setup - http://localhost:8000/v1/tool/auth). +Uncheck the `Webhook->Active` option. After creating the app, you will see the `General` section. Copy the `Client ID`, generate and copy `Client Secret` values. +That will be used for the environment variables specified below. +This tool also support OAuth Apps. See the [documentation](https://docs.github.com/en/apps/oauth-apps) for more information. + +## 2. Set Up Environment Variables +Set the configuration in the `configuration.yaml` +```yaml +github: + default_repos: + - repo1 + - repo2 + user_scopes: + - public_repo + - read:org +``` + +Then set the following secrets variables. You can either set the below values in your `secrets.yaml` file: +```yaml +github: + client_id: + client_secret: +``` +or update your `.env` configuration to contain: +```dotenv +GITHUB_CLIENT_ID= +GITHUB_CLIENT_SECRET= +GITHUB_DEFAULT_REPOS=["repo1","repo2"] +GITHUB_USER_SCOPES=["public_repo","read:org"] +``` +Please note if the default repos are not set, the tool will process over all user repos. + +## 3. Run the Backend and Frontend + +run next command to start the backend and frontend: + +```bash +make dev +``` + +## 4. Troubleshooting + +If you encounter any issues with OAuth, please check the following [link](https://api.Github.com/authentication/oauth-v2#errors) diff --git a/poetry.lock b/poetry.lock index 60ad60700c..1a70de7ab2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4084,43 +4084,31 @@ python-versions = ">=3.9" files = [ {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, - {file = "pandas-2.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d9c45366def9a3dd85a6454c0e7908f2b3b8e9c138f5dc38fed7ce720d8453ed"}, {file = "pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86976a1c5b25ae3f8ccae3a5306e443569ee3c3faf444dfd0f41cda24667ad57"}, - {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b8661b0238a69d7aafe156b7fa86c44b881387509653fdf857bebc5e4008ad42"}, {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37e0aced3e8f539eccf2e099f65cdb9c8aa85109b0be6e93e2baff94264bdc6f"}, {file = "pandas-2.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:56534ce0746a58afaf7942ba4863e0ef81c9c50d3f0ae93e9497d6a41a057645"}, {file = "pandas-2.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66108071e1b935240e74525006034333f98bcdb87ea116de573a6a0dccb6c039"}, {file = "pandas-2.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c2875855b0ff77b2a64a0365e24455d9990730d6431b9e0ee18ad8acee13dbd"}, - {file = "pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd8d0c3be0515c12fed0bdbae072551c8b54b7192c7b1fda0ba56059a0179698"}, {file = "pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c124333816c3a9b03fbeef3a9f230ba9a737e9e5bb4060aa2107a86cc0a497fc"}, - {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:63cc132e40a2e084cf01adf0775b15ac515ba905d7dcca47e9a251819c575ef3"}, {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:29401dbfa9ad77319367d36940cd8a0b3a11aba16063e39632d98b0e931ddf32"}, {file = "pandas-2.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:3fc6873a41186404dad67245896a6e440baacc92f5b716ccd1bc9ed2995ab2c5"}, {file = "pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9"}, {file = "pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4"}, - {file = "pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3"}, {file = "pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319"}, - {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8"}, {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a"}, {file = "pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13"}, {file = "pandas-2.2.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015"}, {file = "pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3508d914817e153ad359d7e069d752cdd736a247c322d932eb89e6bc84217f28"}, - {file = "pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22a9d949bfc9a502d320aa04e5d02feab689d61da4e7764b62c30b991c42c5f0"}, {file = "pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24"}, - {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:800250ecdadb6d9c78eae4990da62743b857b470883fa27f652db8bdde7f6659"}, {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6374c452ff3ec675a8f46fd9ab25c4ad0ba590b71cf0656f8b6daa5202bca3fb"}, {file = "pandas-2.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:61c5ad4043f791b61dd4752191d9f07f0ae412515d59ba8f005832a532f8736d"}, {file = "pandas-2.2.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b71f27954685ee685317063bf13c7709a7ba74fc996b84fc6821c59b0f06468"}, {file = "pandas-2.2.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:38cf8125c40dae9d5acc10fa66af8ea6fdf760b2714ee482ca691fc66e6fcb18"}, - {file = "pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba96630bc17c875161df3818780af30e43be9b166ce51c9a18c1feae342906c2"}, {file = "pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db71525a1538b30142094edb9adc10be3f3e176748cd7acc2240c2f2e5aa3a4"}, - {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:15c0e1e02e93116177d29ff83e8b1619c93ddc9c49083f237d4312337a61165d"}, {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a"}, {file = "pandas-2.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc6b93f9b966093cb0fd62ff1a7e4c09e6d546ad7c1de191767baffc57628f39"}, {file = "pandas-2.2.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5dbca4c1acd72e8eeef4753eeca07de9b1db4f398669d5994086f788a5d7cc30"}, - {file = "pandas-2.2.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8cd6d7cc958a3910f934ea8dbdf17b2364827bb4dafc38ce6eef6bb3d65ff09c"}, {file = "pandas-2.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99df71520d25fade9db7c1076ac94eb994f4d2673ef2aa2e86ee039b6746d20c"}, - {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:31d0ced62d4ea3e231a9f228366919a5ea0b07440d9d4dac345376fd8e1477ea"}, {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7eee9e7cea6adf3e3d24e304ac6b8300646e2a5d1cd3a3c2abed9101b0846761"}, {file = "pandas-2.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:4850ba03528b6dd51d6c5d273c46f183f39a9baf3f0143e566b89450965b105e"}, {file = "pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667"}, @@ -4854,6 +4842,25 @@ azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0 toml = ["tomli (>=2.0.1)"] yaml = ["pyyaml (>=6.0.1)"] +[[package]] +name = "pygithub" +version = "2.5.0" +description = "Use the full Github API v3" +optional = false +python-versions = ">=3.8" +files = [ + {file = "PyGithub-2.5.0-py3-none-any.whl", hash = "sha256:b0b635999a658ab8e08720bdd3318893ff20e2275f6446fcf35bf3f44f2c0fd2"}, + {file = "pygithub-2.5.0.tar.gz", hash = "sha256:e1613ac508a9be710920d26eb18b1905ebd9926aa49398e88151c1b526aad3cf"}, +] + +[package.dependencies] +Deprecated = "*" +pyjwt = {version = ">=2.4.0", extras = ["crypto"]} +pynacl = ">=1.4.0" +requests = ">=2.14.0" +typing-extensions = ">=4.0.0" +urllib3 = ">=1.26.0" + [[package]] name = "pygments" version = "2.18.0" @@ -4879,12 +4886,41 @@ files = [ {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"}, ] +[package.dependencies] +cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"crypto\""} + [package.extras] crypto = ["cryptography (>=3.4.0)"] dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"] docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] +[[package]] +name = "pynacl" +version = "1.5.0" +description = "Python binding to the Networking and Cryptography (NaCl) library" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:52cb72a79269189d4e0dc537556f4740f7f0a9ec41c1322598799b0bdad4ef92"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a36d4a9dda1f19ce6e03c9a784a2921a4b726b02e1c736600ca9c22029474394"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:0c84947a22519e013607c9be43706dd42513f9e6ae5d39d3613ca1e142fba44d"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06b8f6fa7f5de8d5d2f7573fe8c863c051225a27b61e6860fd047b1775807858"}, + {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a422368fc821589c228f4c49438a368831cb5bbc0eab5ebe1d7fac9dded6567b"}, + {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:61f642bf2378713e2c2e1de73444a3778e5f0a38be6fee0fe532fe30060282ff"}, + {file = "PyNaCl-1.5.0-cp36-abi3-win32.whl", hash = "sha256:e46dae94e34b085175f8abb3b0aaa7da40767865ac82c928eeb9e57e1ea8a543"}, + {file = "PyNaCl-1.5.0-cp36-abi3-win_amd64.whl", hash = "sha256:20f42270d27e1b6a29f54032090b972d97f0a1b0948cc52392041ef7831fee93"}, + {file = "PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba"}, +] + +[package.dependencies] +cffi = ">=1.4.1" + +[package.extras] +docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"] +tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"] + [[package]] name = "pyparsing" version = "3.2.0" @@ -6259,6 +6295,11 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, + {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, + {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, + {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, + {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, + {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -6957,4 +6998,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "~3.11" -content-hash = "9cf4c3aa73e1b1181ad0c993c36b9585edf7827f33a4cb1cc1f0c67c6bdc8f87" +content-hash = "d8269dbcb256b0afb1ee469c4bd02e3394c0abf781a7250846fd41873a0b65c2" diff --git a/pyproject.toml b/pyproject.toml index f7adcecc21..42e4e4f06a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ llama-index-embeddings-cohere = "^0.2.1" google-cloud-texttospeech = "^2.18.0" slack-sdk = "^3.33.1" onnxruntime = "1.19.2" +pygithub = "^2.5.0" [tool.poetry.group.dev] optional = true diff --git a/src/backend/config/configuration.template.yaml b/src/backend/config/configuration.template.yaml index 2dfb85bb7d..767edb5012 100644 --- a/src/backend/config/configuration.template.yaml +++ b/src/backend/config/configuration.template.yaml @@ -37,6 +37,12 @@ tools: slack: user_scopes: - search:read + github: + user_scopes: + - public_repo + default_repos: + - cohere-ai/cohere-toolkit + - EugeneLightsOn/cohere-toolkit # To disable the use of the tools preamble, set it to false use_tools_preamble: true feature_flags: diff --git a/src/backend/config/secrets.template.yaml b/src/backend/config/secrets.template.yaml index fcc27fb21f..c4f22b23ed 100644 --- a/src/backend/config/secrets.template.yaml +++ b/src/backend/config/secrets.template.yaml @@ -37,6 +37,9 @@ tools: gmail: client_id: client_secret: + github: + client_id: + client_secret: auth: secret_key: google_oauth: diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index 087d34a9ca..189bdd441e 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -191,7 +191,7 @@ class SlackSettings(BaseSettings, BaseModel): default=None, validation_alias=AliasChoices("SLACK_CLIENT_SECRET", "client_secret"), ) - user_scopes: Optional[str] = Field( + user_scopes: Optional[List[str]] = Field( default=None, validation_alias=AliasChoices( "SLACK_USER_SCOPES", "scopes" @@ -199,6 +199,30 @@ class SlackSettings(BaseSettings, BaseModel): ) +class GithubSettings(BaseSettings, BaseModel): + model_config = SETTINGS_CONFIG + client_id: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("GITHUB_CLIENT_ID", "client_id"), + ) + client_secret: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("GITHUB_CLIENT_SECRET", "client_secret"), + ) + user_scopes: Optional[List[str]] = Field( + default=None, + validation_alias=AliasChoices( + "GITHUB_USER_SCOPES", "user_scopes" + ), + ) + default_repos: Optional[List[str]] = Field( + default=None, + validation_alias=AliasChoices( + "GITHUB_DEFAULT_REPOS", "default_repos" + ), + ) + + class TavilyWebSearchSettings(BaseSettings, BaseModel): model_config = SETTINGS_CONFIG api_key: Optional[str] = Field( @@ -272,6 +296,9 @@ class ToolSettings(BaseSettings, BaseModel): slack: Optional[SlackSettings] = Field( default=SlackSettings() ) + github: Optional[GithubSettings] = Field( + default=GithubSettings() + ) gmail: Optional[GmailSettings] = Field( default=GmailSettings() ) diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index c955757058..742f8339f5 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -18,6 +18,7 @@ TavilyWebSearch, WebScrapeTool, ) +from backend.tools.github.tool import GithubTool logger = LoggerFactory().get_logger() @@ -38,6 +39,7 @@ class Tool(Enum): Hybrid_Web_Search = HybridWebSearch Slack = SlackTool Gmail = GmailTool + Github = GithubTool def get_available_tools() -> dict[str, ToolDefinition]: diff --git a/src/backend/main.py b/src/backend/main.py index 1c3f845845..eb5c336bbb 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -83,9 +83,10 @@ def create_app() -> FastAPI: # Dynamically set router dependencies # These values must be set in config/routers.py dependencies_type = "default" + settings = Settings() if is_authentication_enabled(): # Required to save temporary OAuth state in session - auth_secret = Settings().get('auth.secret_key') + auth_secret = settings.get('auth.secret_key') app.add_middleware(SessionMiddleware, secret_key=auth_secret) dependencies_type = "auth" for router in routers: diff --git a/src/backend/tests/unit/tools/test_lang_chain.py b/src/backend/tests/unit/tools/test_lang_chain.py index dc00bac23a..31ede1ce27 100644 --- a/src/backend/tests/unit/tools/test_lang_chain.py +++ b/src/backend/tests/unit/tools/test_lang_chain.py @@ -72,7 +72,7 @@ async def test_wiki_retriever_no_docs() -> None: ): result = await retriever.call({"query": query}, ctx) - assert result == ToolError(type=ToolErrorCode.OTHER, success=False, text='No results found.', details='No results found for the given params.') + assert result == [ToolError(type=ToolErrorCode.OTHER, success=False, text='No results found.', details='No results found for the given params.').model_dump()] @@ -156,4 +156,4 @@ async def test_vector_db_retriever_no_docs() -> None: mock_db.as_retriever().get_relevant_documents.return_value = mock_docs result = await retriever.call({"query": query}, ctx) - assert result == ToolError(type=ToolErrorCode.OTHER, success=False, text='No results found.', details='No results found for the given params.') + assert result == [ToolError(type=ToolErrorCode.OTHER, success=False, text='No results found.', details='No results found for the given params.').model_dump()] diff --git a/src/backend/tools/base.py b/src/backend/tools/base.py index 396e278d79..d52f615526 100644 --- a/src/backend/tools/base.py +++ b/src/backend/tools/base.py @@ -164,7 +164,8 @@ def get_tool_error(cls, details: str, text: str = "Error calling tool", error_ty @classmethod def get_no_results_error(cls): - return ToolError(text="No results found.", details="No results found for the given params.") + tool_error = ToolError(text="No results found.", details="No results found for the given params.").model_dump() + return [tool_error] @abstractmethod async def call( diff --git a/src/backend/tools/github/__init__.py b/src/backend/tools/github/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/backend/tools/github/auth.py b/src/backend/tools/github/auth.py new file mode 100644 index 0000000000..2fbd1f2fbe --- /dev/null +++ b/src/backend/tools/github/auth.py @@ -0,0 +1,163 @@ +import datetime +import json +import urllib.parse + +import requests +from fastapi import Request + +from backend.config.settings import Settings +from backend.crud import tool_auth as tool_auth_crud +from backend.database_models import ToolAuth +from backend.database_models.database import DBSessionDep +from backend.database_models.tool_auth import ToolAuth as ToolAuthModel +from backend.schemas.tool_auth import UpdateToolAuth +from backend.services.auth.crypto import encrypt +from backend.services.logger.utils import LoggerFactory +from backend.tools.base import BaseToolAuthentication +from backend.tools.github.constants import GITHUB_TOOL_ID +from backend.tools.utils.mixins import ToolAuthenticationCacheMixin + +logger = LoggerFactory().get_logger() + + +class GithubAuth(BaseToolAuthentication, ToolAuthenticationCacheMixin): + TOOL_ID = GITHUB_TOOL_ID + AUTH_ENDPOINT = "https://github.com/login/oauth/authorize" + TOKEN_ENDPOINT = "https://github.com/login/oauth/access_token" + DEFAULT_USER_SCOPES = ['public_repo', 'read:org'] + + def __init__(self): + super().__init__() + + self.GITHUB_CLIENT_ID = Settings().get('tools.github.client_id') + self.GITHUB_CLIENT_SECRET = Settings().get('tools.github.client_secret') + self.USER_SCOPES = Settings().get('tools.github.user_scopes') or self.DEFAULT_USER_SCOPES + self.REDIRECT_URL = f"{self.BACKEND_HOST}/v1/tool/auth" + + if any([ + self.GITHUB_CLIENT_ID is None, + self.GITHUB_CLIENT_SECRET is None + ]): + raise ValueError( + "GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET must be set to use Slack Tool Auth." + ) + + def get_auth_url(self, user_id: str) -> str: + key = self.insert_tool_auth_cache(user_id, self.TOOL_ID) + state = {"key": key} + + params = { + "client_id": self.GITHUB_CLIENT_ID, + "scope": " ".join(self.USER_SCOPES or []), + "redirect_uri": self.REDIRECT_URL, + "state": json.dumps(state), + } + + return f"{self.AUTH_ENDPOINT}?{urllib.parse.urlencode(params)}" + + def retrieve_auth_token( + self, request: Request, session: DBSessionDep, user_id: str + ) -> str: + if request.query_params.get("error"): + error = request.query_params.get("error") or "Unknown error" + logger.error(event=f"[Github Tool] Auth token error: {error}.") + return error + + body = { + "code": request.query_params.get("code"), + "client_id": self.GITHUB_CLIENT_ID, + "client_secret": self.GITHUB_CLIENT_SECRET, + } + + url_encoded_body = urllib.parse.urlencode(body) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + response = requests.post(self.TOKEN_ENDPOINT, data=url_encoded_body, headers=headers) + + response_body = response.json() + + if response.status_code != 200: + logger.error( + event=f"[Github Tool] Error retrieving auth token: {response_body}" + ) + return str(response) + + token = response_body.get("access_token", None) + token_type = response_body.get("token_type", None) + refresh_token = response_body.get("refresh_token", "") + expires_in = response_body.get("expires_in", 31536000) + + if token is None: + logger.error( + event=f"[Github Tool] Error retrieving auth token: {response_body}" + ) + return str(response) + + tool_auth_crud.create_tool_auth( + session, + ToolAuthModel( + user_id=user_id, + tool_id=self.TOOL_ID, + token_type=token_type, + encrypted_access_token=encrypt(token), + encrypted_refresh_token=encrypt(refresh_token), + expires_at=datetime.datetime.now() + + datetime.timedelta(seconds=expires_in) + ), + ) + + return "" + + def try_refresh_token(self, session: DBSessionDep, user_id: str, tool_auth: ToolAuth) -> bool: + body = { + "client_id": self.GITHUB_CLIENT_ID, + "client_secret": self.GITHUB_CLIENT_SECRET, + "refresh_token": tool_auth.refresh_token, + "grant_type": "refresh_token", + } + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + url_encoded_body = urllib.parse.urlencode(body) + response = requests.post(self.TOKEN_ENDPOINT, data=url_encoded_body, headers=headers) + response_body = response.json() + + if response.status_code != 200: + logger.error( + event=f"[GITHUB Tool] Error refreshing token: {response_body}" + ) + return False + + token = response_body.get("access_token", None) + token_type = response_body.get("token_type", None) + refresh_token = response_body.get("refresh_token", "") + expires_in = response_body.get("expires_in", 31536000) + + if token is None: + logger.error( + event=f"[GITHUB Tool] Error retrieving auth token: {response_body}" + ) + return False + + existing_tool_auth = tool_auth_crud.get_tool_auth( + session, self.TOOL_ID, user_id + ) + tool_auth_crud.update_tool_auth( + session, + existing_tool_auth, + UpdateToolAuth( + user_id=user_id, + tool_id=self.TOOL_ID, + token_type=token_type, + encrypted_access_token=encrypt(token), + encrypted_refresh_token=encrypt(refresh_token), + expires_at=datetime.datetime.now() + + datetime.timedelta(seconds=expires_in), + ), + ) + + return True diff --git a/src/backend/tools/github/client.py b/src/backend/tools/github/client.py new file mode 100644 index 0000000000..8a91d47376 --- /dev/null +++ b/src/backend/tools/github/client.py @@ -0,0 +1,35 @@ +from github import Auth, Github + +from backend.tools.github.constants import SEARCH_LIMIT + + +class GithubClient: + def __init__(self, auth_token, search_limit=SEARCH_LIMIT): + auth = Auth.Token(auth_token) + self.client = Github(auth=auth, per_page=search_limit) + + def search_all(self, query: str): + code_results = self.search_code(query) + all_results = {"code": code_results} + return all_results + + + def search_code(self, query: str): + return self.client.search_code(query).get_page(0) + + def search_repositories(self, query: str): + return self.client.search_repositories(query).get_page(0) + + def search_pull_requests(self, query: str): + return self.client.search_issues(f"{query} is:pr").get_page(0) + + def search_commits(self, query: str): + return self.client.search_commits(query).get_page(0) + + def get_user(self): + return self.client.get_user() + + def get_user_repositories(self, user): + return user.get_repos(type="all") + + diff --git a/src/backend/tools/github/constants.py b/src/backend/tools/github/constants.py new file mode 100644 index 0000000000..f033ed4584 --- /dev/null +++ b/src/backend/tools/github/constants.py @@ -0,0 +1,2 @@ +SEARCH_LIMIT = 10 +GITHUB_TOOL_ID = "github" diff --git a/src/backend/tools/github/tool.py b/src/backend/tools/github/tool.py new file mode 100644 index 0000000000..e218638626 --- /dev/null +++ b/src/backend/tools/github/tool.py @@ -0,0 +1,82 @@ +from typing import Any, Dict, List, Union + +from backend.config.settings import Settings +from backend.crud import tool_auth as tool_auth_crud +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.services.logger.utils import LoggerFactory +from backend.tools.base import BaseTool, ToolError +from backend.tools.github.auth import GithubAuth +from backend.tools.github.constants import GITHUB_TOOL_ID, SEARCH_LIMIT +from backend.tools.github.utils import get_github_service + +logger = LoggerFactory().get_logger() + + +class GithubTool(BaseTool): + """ + Tool that searches Github for repositories based on a query. + """ + ID = GITHUB_TOOL_ID + CLIENT_ID = Settings().get('tools.github.client_id') + CLIENT_SECRET = Settings().get('tools.github.client_secret') + DEFAULT_REPOS = Settings().get('tools.github.default_repos') + + @classmethod + def is_available(cls) -> bool: + return cls.CLIENT_ID is not None and cls.CLIENT_SECRET is not None + + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Github", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query to search Github.", + "type": "str", + "required": True, + } + }, + is_visible=True, + is_available=cls.is_available(), + auth_implementation=GithubAuth, + should_return_token=False, + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Returns a list of relevant document snippets from Github.", + ) # type: ignore + + @classmethod + def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: + message = "[Github] Tool Error: {}".format(str(error)) + + if error: + session = kwargs["session"] + user_id = kwargs["user_id"] + tool_auth_crud.delete_tool_auth( + db=session, user_id=user_id, tool_id=GITHUB_TOOL_ID + ) + + logger.error( + event="[Github] Auth token error: Please refresh the page and re-authenticate." + ) + raise Exception(message) + + async def call(self, parameters: dict, ctx: Any, **kwargs: Any) -> Union[List[Dict[str, Any]], ToolError]: + user_id = kwargs.get("user_id", "") + query = parameters.get("query", "") + + # Search Slack + github_service = get_github_service(user_id=user_id, default_repos=self.DEFAULT_REPOS, + search_limit=SEARCH_LIMIT) + try: + all_results = github_service.search(query=query) + results = github_service.transform_response(all_results) + except Exception as e: + return self.get_tool_error(details=str(e)) + + if not results: + return self.get_no_results_error() + + return results diff --git a/src/backend/tools/github/utils.py b/src/backend/tools/github/utils.py new file mode 100644 index 0000000000..39da1dd763 --- /dev/null +++ b/src/backend/tools/github/utils.py @@ -0,0 +1,72 @@ +from backend.database_models.database import get_session +from backend.services.logger.utils import LoggerFactory +from backend.tools.base import ToolAuthException +from backend.tools.github.auth import GithubAuth +from backend.tools.github.client import GithubClient +from backend.tools.github.constants import GITHUB_TOOL_ID, SEARCH_LIMIT + +logger = LoggerFactory().get_logger() + + +class GithubService: + def __init__(self, user_id: str, auth_token: str, default_repos: list[str], search_limit=SEARCH_LIMIT): + self.user_id = user_id + self.auth_token = auth_token + self.client = GithubClient(auth_token=auth_token, search_limit=search_limit) + self.github_user = self.client.get_user() + self.repositories = self.client.get_user_repositories(self.github_user) + if default_repos: + self.repositories = [repo for repo in self.repositories if repo.full_name in default_repos] + + + @staticmethod + def prepare_repo_query(query: str, repo): + if repo.fork: + query += " fork:true" + return f"{query} repo:{repo.full_name}" + + @staticmethod + def _extract_code_data(code): + return { + "title": code.path, + "text": code.decoded_content.decode("utf-8"), + "url": code.html_url, + "type": "code", + } + + def search(self, query: str): + results = {"code": []} + for repo in self.repositories: + prepared_query = self.prepare_repo_query(query, repo) + repo_results = self.client.search_all(query=prepared_query) + results["code"].extend(repo_results["code"]) + return results + + def transform_response(self, response): + results = [] + for code in response["code"]: + results.append(self._extract_code_data(code)) + return results + + +def get_github_service(user_id: str, default_repos: list[str], search_limit=SEARCH_LIMIT) -> GithubService: + github_auth = GithubAuth() + session = next(get_session()) + if github_auth.is_auth_required(session, user_id=user_id): + session.close() + raise ToolAuthException( + "Github Tool auth Error: Agent creator credentials need to re-authenticate", + GITHUB_TOOL_ID, + ) + + auth_token = github_auth.get_token(session=session, user_id=user_id) + if auth_token is None: + session.close() + raise Exception("Github Tool Error: No credentials found") + + service = GithubService(user_id=user_id, auth_token=auth_token, default_repos=default_repos, search_limit=search_limit) + session.close() + return service + + + diff --git a/src/backend/tools/slack/tool.py b/src/backend/tools/slack/tool.py index 20c9616374..9fde976d0a 100644 --- a/src/backend/tools/slack/tool.py +++ b/src/backend/tools/slack/tool.py @@ -41,6 +41,7 @@ def get_tool_definition(cls) -> ToolDefinition: is_visible=True, is_available=cls.is_available(), auth_implementation=SlackAuth, + should_return_token=False, error_message=cls.generate_error_message(), category=ToolCategory.DataLoader, description="Returns a list of relevant document snippets from slack.", diff --git a/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx b/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx index 094b3d28ce..b2acc46172 100644 --- a/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/settings/Settings.tsx @@ -8,35 +8,24 @@ import { Button, DarkModeToggle, Icon, + IconName, ShowCitationsToggle, ShowStepsToggle, Tabs, Text, } from '@/components/UI'; -import { TOOL_GMAIL_ID, TOOL_SLACK_ID } from '@/constants'; +import { TOOL_GITHUB_ID, TOOL_GMAIL_ID, TOOL_SLACK_ID } from '@/constants'; import { useDeleteAuthTool, useListTools, useNotify } from '@/hooks'; import { cn, getToolAuthUrl } from '@/utils'; -const tabs = [ -
- - Connections -
, -
- - Appearance -
, -
- - Advanced -
, -
- - Profile -
, +const tabs: { key: string; icon: IconName; label: string }[] = [ + { key: 'connections', icon: 'users-three', label: 'Connections' }, + { key: 'appearance', icon: 'sun', label: 'Appearance' }, + { key: 'advanced', icon: 'settings', label: 'Advanced' }, + { key: 'profile', icon: 'profile', label: 'Profile' }, ]; -export const Settings = () => { +const Settings = () => { const [selectedTabIndex, setSelectedTabIndex] = useState(0); return ( @@ -57,7 +46,12 @@ export const Settings = () => {
( +
+ + {tab.label} +
+ ))} selectedIndex={selectedTabIndex} onChange={setSelectedTabIndex} tabGroupClassName="h-full" @@ -88,119 +82,108 @@ const Connections = () => ( + ); -const Appearance = () => { - return ( - - - Mode - - - - ); -}; +const Appearance = () => ( + + + Mode + + + +); -const Advanced = () => { - return ( - - - Advanced - - - - - ); -}; +const Advanced = () => ( + + + Advanced + + + + +); -const Profile = () => { - return ( - - - User Profile - -
- - ); -}; - -const GmailConnection = () => { - const { data } = useListTools(); - const { mutateAsync: deleteAuthTool } = useDeleteAuthTool(); - const notify = useNotify(); - const gmailTool = data?.find((tool) => tool.name === TOOL_GMAIL_ID); +const GoogleDriveConnection = () => ( + +); - if (!gmailTool) { - return null; - } +const SlackConnection = () => ( + +); - const handleDeleteAuthTool = async () => { - try { - await deleteAuthTool(gmailTool.name!); - } catch (e) { - notify.error('Failed to delete Gmail connection'); - } - }; +const GmailConnection = () => ( + +); - const isGmailConnected = !(gmailTool.is_auth_required ?? false); +const GithubConnection = () => ( + +); - return ( -
-
-
- - Gmail -
- -
- Connect to Gmail -
- {isGmailConnected ? ( -
-
-
-
- ) : ( -
-
- ); -}; +export { Settings }; diff --git a/src/interfaces/assistants_web/src/assets/icons/Github.tsx b/src/interfaces/assistants_web/src/assets/icons/Github.tsx new file mode 100644 index 0000000000..750afa19bb --- /dev/null +++ b/src/interfaces/assistants_web/src/assets/icons/Github.tsx @@ -0,0 +1,21 @@ +import * as React from 'react'; +import { SVGProps } from 'react'; + +import { cn } from '@/utils'; + +export const Github: React.FC> = ({ className, ...props }) => ( + + + +); diff --git a/src/interfaces/assistants_web/src/components/UI/Icon.tsx b/src/interfaces/assistants_web/src/components/UI/Icon.tsx index 5f0a1c6c9a..1a00622dc8 100644 --- a/src/interfaces/assistants_web/src/components/UI/Icon.tsx +++ b/src/interfaces/assistants_web/src/components/UI/Icon.tsx @@ -63,6 +63,7 @@ import { Warning, Web, } from '@/assets/icons'; +import { Github } from '@/assets/icons/Github'; import { cn } from '@/utils'; export const IconList = [ @@ -128,6 +129,7 @@ export const IconList = [ 'web', 'slack', 'gmail', + 'github', ] as const; export type IconName = (typeof IconList)[number]; @@ -482,6 +484,11 @@ const getIcon = (name: IconName, kind: IconKind): React.ReactNode => { ), + ['github']: ( + + + + ), ['hot-keys']: ( diff --git a/src/interfaces/assistants_web/src/constants/tools.ts b/src/interfaces/assistants_web/src/constants/tools.ts index 005492968b..029433bdd2 100644 --- a/src/interfaces/assistants_web/src/constants/tools.ts +++ b/src/interfaces/assistants_web/src/constants/tools.ts @@ -14,6 +14,7 @@ export const TOOL_WEB_SCRAPE_ID = 'web_scrape'; export const TOOL_GOOGLE_DRIVE_ID = 'google_drive'; export const TOOL_SLACK_ID = 'slack'; export const TOOL_GMAIL_ID = 'gmail'; +export const TOOL_GITHUB_ID = 'github'; export const BACKGROUND_TOOLS = [TOOL_SEARCH_FILE_ID, TOOL_READ_DOCUMENT_ID]; @@ -30,4 +31,5 @@ export const TOOL_ID_TO_DISPLAY_INFO: { [id: string]: { icon: IconName } } = { [TOOL_READ_DOCUMENT_ID]: { icon: 'desktop' }, [TOOL_SLACK_ID]: { icon: 'slack' }, [TOOL_GMAIL_ID]: { icon: 'gmail' }, + [TOOL_GITHUB_ID]: { icon: 'github' }, };