Skip to content
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
1dc63ab
wip
tianjing-li May 8, 2024
1066962
wip google oauth
tianjing-li May 14, 2024
45bbb90
Add auth code
tianjing-li May 14, 2024
ff9a877
Add Google redirect logic
tianjing-li May 15, 2024
06ca28e
Change class name
tianjing-li May 15, 2024
6185e6e
Merge branch 'main' into add-google-oauth
tianjing-li May 29, 2024
6e249fe
pr comments
tianjing-li May 29, 2024
112ef51
fix tests
tianjing-li May 29, 2024
a48333d
Merge branch 'main' into add-google-oauth
tianjing-li May 29, 2024
91c5d9c
Add JWT logic for basic auth + oauth, move router dependencies to cre…
tianjing-li May 31, 2024
eaa710f
Revert makefile
tianjing-li May 31, 2024
7c9c88a
revert compose
tianjing-li May 31, 2024
0474788
Merge conflict fixes
tianjing-li May 31, 2024
a2a8f04
disable auth by default
tianjing-li Jun 4, 2024
3aeb1f3
Merge branch 'main' into add-jwt
tianjing-li Jun 4, 2024
f9d20f3
Add header logic
tianjing-li Jun 4, 2024
f7e028b
Add header logic everywhere
tianjing-li Jun 4, 2024
576299a
Merge branch 'add-jwt' of https://github.com/cohere-ai/cohere-toolkit…
tianjing-li Jun 4, 2024
305db8a
remove basic auth
tianjing-li Jun 4, 2024
5c5229e
add sentencepiece
tianjing-li Jun 4, 2024
082c358
update deps
tianjing-li Jun 4, 2024
a0f79c0
fix freezegun test issue
tianjing-li Jun 4, 2024
21e08f2
Merge branch 'main' into add-jwt
tianjing-li Jun 5, 2024
7a9c36f
Small fixes to OAuth
tianjing-li Jun 5, 2024
cf007de
Merge branch 'add-jwt' of https://github.com/cohere-ai/cohere-toolkit…
tianjing-li Jun 5, 2024
bbe6d84
pr comments
tianjing-li Jun 5, 2024
ad1051e
format
tianjing-li Jun 6, 2024
3b0e845
Merge branch 'main' into add-jwt
tianjing-li Jun 6, 2024
abca2bc
merge main
tianjing-li Jun 6, 2024
c26af9a
Merge branch 'add-jwt' of https://github.com/cohere-ai/cohere-toolkit…
tianjing-li Jun 6, 2024
01e4f73
small fix
tianjing-li Jun 6, 2024
ae4847b
revert makefile
tianjing-li Jun 6, 2024
8dfc91a
Merge branch 'main' into add-jwt
tianjing-li Jun 6, 2024
e43c744
revert community deps
tianjing-li Jun 6, 2024
bd37798
Merge branch 'add-jwt' of https://github.com/cohere-ai/cohere-toolkit…
tianjing-li Jun 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .env-template
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ USE_EXPERIMENTAL_LANGCHAIN=False
# Community features
USE_COMMUNITY_FEATURES='True'

# For setting up authentication, see: docs/auth_guide.md
# Authentication session
SESSION_SECRET_KEY=<GENERATE_A_SECRET_KEY>
# For setting up authentication, see: docs/auth_guide.md
JWT_SECRET_KEY=

# Google OAuth
GOOGLE_CLIENT_ID=<GOOGLE_CLIENT_ID>
Expand Down
7 changes: 0 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,10 @@ reset-db:
setup:
poetry install --with setup,community --verbose
poetry run python3 src/backend/cli/main.py
win-setup:
poetry install --with setup,community --verbose
poetry run python src/backend/cli/main.py
lint:
poetry run black .
poetry run isort .
first-run:
make setup
make migrate
make dev
win-first-run:
make win-setup
make migrate
make dev
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ services:
build:
context: .
args:
INSTALL_COMMUNITY_DEPS: false
INSTALL_COMMUNITY_DEPS: true
dockerfile: ./src/backend/Dockerfile
develop:
watch:
Expand Down
10 changes: 10 additions & 0 deletions docs/auth_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,13 @@ The list of implemented authentication strategies exist in `src/backend/services
- GoogleOAuth: requires setting up [Google OAuth 2.0](https://support.google.com/cloud/answer/6158849?hl=en). You will need to retrieve a client ID and client secret and set them as environment variables.

To enable one or more of these strategies, simply add them to the `ENABLED_AUTH_STRATEGIES` list in the configurations.

After enabling one or more strategies, you must create a secret key to be used to encrypt the JWT tokens generated by the backend and store it in the `JWT_SECRET_KEY` environment variable.

For testing use-cases, you can enter any string value.
For production use-cases, We recommend running the following in a local CLI to generate a random key:

```
import secrets
print(secrets.token_hex(32))
```
543 changes: 290 additions & 253 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ xmltodict = "^0.13.0"
authlib = "^1.3.0"
itsdangerous = "^2.2.0"
bcrypt = "^4.1.2"
pyjwt = "^2.8.0"

[tool.poetry.group.dev]
optional = true
Expand All @@ -45,6 +46,7 @@ pytest = "^7.1.2"
pytest-env = "^1.1.3"
pytest-cov = "^5.0.0"
factory-boy = "^3.3.0"
freezegun = "^1.5.1"
pre-commit = "^2.20.0"
ruff = "^0.0.94"
isort = "^5.12.0"
Expand Down
10 changes: 5 additions & 5 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from backend.schemas.tool import Category, Tool
from backend.services.logger import get_logger

logger = get_logger()


class CustomChat(BaseChat):
"""Custom chat flow not using integrations for models."""

logger = get_logger()

def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
"""
Chat flow for custom models.
Expand All @@ -31,7 +31,7 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
"""
# Choose the deployment model - validation already performed by request validator
deployment_model = get_deployment(kwargs.get("deployment_name"), **kwargs)
self.logger.info(f"Using deployment {deployment_model.__class__.__name__}")
logger.info(f"Using deployment {deployment_model.__class__.__name__}")

if len(chat_request.tools) > 0 and len(chat_request.documents) > 0:
raise HTTPException(
Expand Down Expand Up @@ -68,13 +68,13 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
queries = deployment_model.invoke_search_queries(
chat_request.message, chat_history
)
self.logger.info(f"Search queries generated: {queries}")
logger.info(f"Search queries generated: {queries}")

# Fetch Documents
retrievers = self.get_retrievers(
kwargs.get("file_paths", []), [tool.name for tool in chat_request.tools]
)
self.logger.info(
logger.info(
f"Using retrievers: {[retriever.__class__.__name__ for retriever in retrievers]}"
)

Expand Down
16 changes: 15 additions & 1 deletion src/backend/config/auth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
from backend.services.auth import BasicAuthentication, GoogleOAuth

# Modify this to enable auth strategies.
# Add Auth strategy classes here to enable them
# Ex: [BasicAuthentication]
ENABLED_AUTH_STRATEGIES = []

# Define the mapping from Auth strategy name to class obj - does not need to be manually modified.
# During runtime, this will create an instance of each enabled strategy class.
# Ex: {"Basic": BasicAuthentication()}
ENABLED_AUTH_STRATEGY_MAPPING = {cls.NAME: cls() for cls in ENABLED_AUTH_STRATEGIES}


def is_authentication_enabled() -> bool:
"""
Check whether any form of authentication was enabled.

Returns:
bool: Whether authentication is enabled.
"""
if ENABLED_AUTH_STRATEGIES:
return True

return False
90 changes: 90 additions & 0 deletions src/backend/config/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from enum import StrEnum

from fastapi import Depends

from backend.database_models import get_session
from backend.services.auth.request_validators import validate_authorization
from backend.services.request_validators import (
validate_chat_request,
validate_user_header,
)


# Important! Any new routers must have a corresponding RouterName entry and Router dependencies
# mapping below. Make sure they use the correct ones depending on whether authentication is enabled or not.
class RouterName(StrEnum):
AUTH = "auth"
CHAT = "chat"
CONVERSATION = "conversation"
DEPLOYMENT = "deployment"
EXPERIMENTAL_FEATURES = "experimental_features"
TOOL = "tool"
USER = "user"


# Router dependency mappings
ROUTER_DEPENDENCIES = {
RouterName.AUTH: {
"default": [
Depends(get_session),
],
"auth": [
Depends(get_session),
],
},
RouterName.CHAT: {
"default": [
Depends(get_session),
Depends(validate_chat_request),
Depends(validate_user_header),
],
"auth": [
Depends(get_session),
Depends(validate_chat_request),
Depends(validate_authorization),
],
},
RouterName.CONVERSATION: {
"default": [
Depends(get_session),
Depends(validate_user_header),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
],
},
RouterName.DEPLOYMENT: {
"default": [
Depends(get_session),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
],
},
RouterName.EXPERIMENTAL_FEATURES: {
"default": [
Depends(get_session),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
],
},
RouterName.TOOL: {
"default": [],
"auth": [
Depends(validate_authorization),
],
},
RouterName.USER: {
"default": [
Depends(get_session),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
],
},
}
52 changes: 24 additions & 28 deletions src/backend/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import logging
import os
from contextlib import asynccontextmanager

from alembic.command import upgrade
from alembic.config import Config
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware

from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING
from backend.config.auth import is_authentication_enabled
from backend.config.routers import ROUTER_DEPENDENCIES
from backend.routers.auth import router as auth_router
from backend.routers.chat import router as chat_router
from backend.routers.conversation import router as conversation_router
Expand All @@ -35,14 +33,28 @@ async def lifespan(app: FastAPI):
def create_app():
app = FastAPI(lifespan=lifespan)

# Add routers
app.include_router(auth_router)
app.include_router(chat_router)
app.include_router(user_router)
app.include_router(conversation_router)
app.include_router(tool_router)
app.include_router(deployment_router)
app.include_router(experimental_feature_router)
routers = [
auth_router,
chat_router,
user_router,
conversation_router,
tool_router,
deployment_router,
experimental_feature_router,
]

# Dynamically set router dependencies
# These values must be set in config/routers.py
dependencies_type = "default"
if is_authentication_enabled():
dependencies_type = "auth"
for router in routers:
if getattr(router, "name", "") in ROUTER_DEPENDENCIES.keys():
router_name = router.name
dependencies = ROUTER_DEPENDENCIES[router_name][dependencies_type]
app.include_router(router, dependencies=dependencies)
else:
app.include_router(router)

# Add middleware
app.add_middleware(
Expand All @@ -54,22 +66,6 @@ def create_app():
)
app.add_middleware(LoggingMiddleware)

# Handle Authentication enabled
if ENABLED_AUTH_STRATEGY_MAPPING:
secret_key = os.environ.get("SESSION_SECRET_KEY", None)

if not secret_key:
raise ValueError(
"Missing SESSION_SECRET_KEY environment variable to enable Authentication."
)

# Handle User sessions and Auth
app.add_middleware(
SessionMiddleware,
secret_key=secret_key,
max_age=SESSION_EXPIRY,
)

return app


Expand Down
Loading