
Genai demo #2158


Closed
wants to merge 2 commits into from
5 changes: 5 additions & 0 deletions .dockerignore
@@ -17,6 +17,11 @@ src/*/*/Dockerfile
# frontend
./src/frontend/node_modules/

###################################
# genai-sql-agent
./src/genai-sql-agent/.venv
./src/genai-sql-agent/__pycache__

###################################
# shipping
./src/shipping/target
5 changes: 5 additions & 0 deletions .env
@@ -87,6 +87,11 @@ ENVOY_PORT=8080
FRONTEND_PROXY_ADDR=frontend-proxy:${ENVOY_PORT}
FRONTEND_PROXY_DOCKERFILE=./src/frontend-proxy/Dockerfile

# GenAI SQL Agent
GENAI_SQL_AGENT_HOST=genai-sql-agent
GENAI_SQL_AGENT_PORT=8501
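# (8501 is Streamlit's default serving port)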
GENAI_SQL_AGENT_DOCKERFILE=./src/genai-sql-agent/Dockerfile

# Image Provider
IMAGE_PROVIDER_HOST=image-provider
IMAGE_PROVIDER_PORT=8081
13 changes: 13 additions & 0 deletions docker-compose.gcp.yml
@@ -0,0 +1,13 @@
# Copyright The OpenTelemetry Authors
# SPDX-License-Identifier: Apache-2.0
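#
# Optional overlay that grants the GenAI agent access to Google Cloud
# credentials. A typical invocation (assumed here, not mandated by the repo)
# stacks it on the base file:
#   docker compose -f docker-compose.yml -f docker-compose.gcp.yml up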

services:
  genai-sql-agent:
    # If the container does not have permission to read the mounted volumes, set
    # USERID=$(id -u) to run the container as the current user
    # user: $USERID
    volumes:
      - ${GOOGLE_APPLICATION_CREDENTIALS?}:${GOOGLE_APPLICATION_CREDENTIALS}:ro
    environment:
      - GOOGLE_APPLICATION_CREDENTIALS
      - GOOGLE_CLOUD_PROJECT
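# Note: the "?" in ${GOOGLE_APPLICATION_CREDENTIALS?} makes compose fail fast
# when the variable is unset, rather than attempting an invalid empty mount.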
29 changes: 29 additions & 0 deletions docker-compose.yml
@@ -336,6 +336,8 @@ services:
      - FRONTEND_HOST
      - LOCUST_WEB_HOST
      - LOCUST_WEB_PORT
      - GENAI_SQL_AGENT_HOST
      - GENAI_SQL_AGENT_PORT
      - GRAFANA_PORT
      - GRAFANA_HOST
      - JAEGER_PORT
@@ -365,6 +367,33 @@ services:
        condition: service_started
    dns_search: ""

  # GenAI SQL agent
  genai-sql-agent:
    image: ${IMAGE_NAME}:${DEMO_VERSION}-genai-sql-agent
    container_name: genai-sql-agent
    build:
      context: ./
      dockerfile: ${GENAI_SQL_AGENT_DOCKERFILE}
      cache_from:
        - ${IMAGE_NAME}:${IMAGE_VERSION}-genai-sql-agent
    restart: unless-stopped
    ports:
      - "${GENAI_SQL_AGENT_PORT}"
    environment:
      - GENAI_SQL_AGENT_PORT
      - OTEL_EXPORTER_OTLP_ENDPOINT
      - OTEL_EXPORTER_OTLP_METRICS_TEMPORALITY_PREFERENCE
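      # Opt-in: allow the instrumentation to record prompt and response message content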
      - OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT=true
      - OTEL_PYTHON_LOG_CORRELATION=true
      - OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED=true
      - OTEL_RESOURCE_ATTRIBUTES
      - OTEL_SERVICE_NAME=genai-sql-agent
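      # Selects the pure-Python protobuf runtime (slower, but presumably set to
      # avoid conflicts with the compiled extension)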
      - PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
    depends_on:
      otel-collector:
        condition: service_started
    logging: *logging

# image-provider
image-provider:
image: ${IMAGE_NAME}:${DEMO_VERSION}-image-provider
17 changes: 16 additions & 1 deletion src/frontend-proxy/envoy.tmpl.yaml
@@ -1,7 +1,6 @@
# Copyright The OpenTelemetry Authors
# SPDX-License-Identifier: Apache-2.0


static_resources:
  listeners:
    - address:
@@ -30,6 +29,8 @@ static_resources:
                        - name: envoy.tracers.opentelemetry.resource_detectors.environment
                          typed_config:
                            "@type": type.googleapis.com/envoy.extensions.tracers.opentelemetry.resource_detectors.v3.EnvironmentResourceDetectorConfig
                upgrade_configs:
                  - upgrade_type: websocket
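                # Streamlit talks to the browser over a persistent websocket,
                # so the proxy must allow connection upgrades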
                route_config:
                  name: local_route
                  virtual_hosts:
@@ -51,6 +52,8 @@ static_resources:
                          route: { cluster: flagservice, prefix_rewrite: "/", timeout: 0s }
                        - match: { prefix: "/feature" }
                          route: { cluster: flagd-ui }
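                        # The prefix matches --server.baseUrlPath=genai-sql-agent
                        # in the agent's Dockerfile, so no prefix_rewrite is needed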
- match: { prefix: "/genai-sql-agent" }
route: { cluster: genai-sql-agent }
- match: { prefix: "/" }
route: { cluster: frontend }
http_filters:
@@ -160,6 +163,18 @@ static_resources:
                    socket_address:
                      address: ${FRONTEND_HOST}
                      port_value: ${FRONTEND_PORT}
    - name: genai-sql-agent
      type: STRICT_DNS
      lb_policy: ROUND_ROBIN
      load_assignment:
        cluster_name: genai-sql-agent
        endpoints:
          - lb_endpoints:
              - endpoint:
                  address:
                    socket_address:
                      address: ${GENAI_SQL_AGENT_HOST}
                      port_value: ${GENAI_SQL_AGENT_PORT}
    - name: image-provider
      type: STRICT_DNS
      lb_policy: ROUND_ROBIN
1 change: 1 addition & 0 deletions src/genai-sql-agent/.python-version
@@ -0,0 +1 @@
3.13
25 changes: 25 additions & 0 deletions src/genai-sql-agent/Dockerfile
@@ -0,0 +1,25 @@
# Copyright The OpenTelemetry Authors
# SPDX-License-Identifier: Apache-2.0

# Adapted from https://docs.astral.sh/uv/guides/integration/docker/#intermediate-layers
FROM python:3.13-slim-bookworm
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

# Change the working directory to the `app` directory
WORKDIR /app

# Install dependencies
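# Syncing dependencies from the lockfile before copying the source keeps this
# layer cached across code-only changes; --no-install-project defers installing
# the app itself until the source is copied below.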
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=./src/genai-sql-agent/uv.lock,target=uv.lock \
    --mount=type=bind,source=./src/genai-sql-agent/pyproject.toml,target=pyproject.toml \
    uv sync --locked --no-install-project
ENV PATH="/app/.venv/bin/:$PATH"

# Copy the project into the image
COPY ./src/genai-sql-agent/ /app

# Sync the project
RUN --mount=type=cache,target=/root/.cache/uv \
    uv sync --locked
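
# opentelemetry-instrument wraps the Streamlit entry point with Python
# auto-instrumentation; --server.baseUrlPath matches the /genai-sql-agent
# route added to envoy.tmpl.yaml above.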

CMD [ "opentelemetry-instrument", "streamlit", "run", "--client.toolbarMode=developer", "--server.baseUrlPath=genai-sql-agent", "agent.py" ]
224 changes: 224 additions & 0 deletions src/genai-sql-agent/agent.py
@@ -0,0 +1,224 @@
# Copyright The OpenTelemetry Authors
# SPDX-License-Identifier: Apache-2.0

"""Adapted from https://github.com/langchain-ai/streamlit-agent/blob/main/streamlit_agent/basic_memory.py"""

import logging
import sqlite3
import tempfile
from random import getrandbits

import streamlit as st
from google.cloud import storage
from google.cloud.exceptions import NotFound
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables.config import (
    RunnableConfig,
)
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.prebuilt import create_react_agent
from patched_vertexai import PatchedChatVertexAI
import streamlit_helpers
from sqlalchemy import Engine, create_engine

from opentelemetry import trace
from opentelemetry.trace.span import format_trace_id

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

_ = """
Ideas for things to add:

- Show the trace ID and possibly a link to the trace
- Download the sqlite db
- Some kind of multimedia input/output
"""

tracer = trace.get_tracer(__name__)

title = "LangGraph SQL Agent Demo"
st.set_page_config(page_title=title, page_icon="📖", layout="wide")
st.title(title)
streamlit_helpers.styles()


model = PatchedChatVertexAI(model="gemini-2.0-flash")
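# (PatchedChatVertexAI is the service-local wrapper around ChatVertexAI,
# defined in patched_vertexai.py alongside this file.)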

if not st.query_params.get("thread_id"):
    result = model.invoke(
        "Generate a random name composed of an adjective and a noun, to use as a default value in a "
        "web page. Just return the name with no surrounding whitespace, and no other text.",
        max_tokens=50,
        seed=getrandbits(31),
    )
    st.query_params.thread_id = str(result.content).strip()
if "upload_key" not in st.session_state:
    st.session_state.upload_key = 0


# Initialize memory to persist state between graph runs
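# st.cache_resource memoizes across Streamlit reruns, so one shared InMemorySaver
# (and with it the conversation history) outlives each rerun of this script.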
@st.cache_resource
def get_checkpointer() -> InMemorySaver:
    return InMemorySaver()


checkpointer = get_checkpointer()
with st.sidebar.container():
    streamlit_helpers.render_sidebar(checkpointer)


@st.cache_resource
def get_storage_bucket() -> storage.Bucket:
    storage_client = storage.Client()
    bucket_name = f"{streamlit_helpers.get_project_id()}-langgraph-chatbot-storage"
    try:
        return storage_client.get_bucket(bucket_name)
    except NotFound:
        return storage_client.create_bucket(bucket_name)


bucket = get_storage_bucket()


# Define the tools for the agent to use
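# Wrapping a tool in tracer.start_as_current_span records a span per tool call
# inside the currently active agent trace.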
@tool
@tracer.start_as_current_span("tool search")
def search(query: str):
"""Call to surf the web."""
# This is a placeholder, but don't tell the LLM that...
if "sf" in query.lower() or "san francisco" in query.lower():
return "It's 60 degrees and foggy."
return "It's 90 degrees and sunny."


system_prompt = SystemMessage(
    content=f"""\
You are a careful and helpful AI assistant with a mastery of database design and querying. You
have access to an ephemeral sqlite3 database that you can query and modify through some tools.
Help answer questions and perform actions. Follow these rules:

- Make sure you always use sql_db_query_checker to validate SQL statements **before** running
  them! In pseudocode: `checked_query = sql_db_query_checker(query);
  sql_db_query(checked_query)`.
- The sqlite version is {sqlite3.sqlite_version} which supports multiple row inserts.
- Always prefer to insert multiple rows in a single call to the sql_db_query tool, if possible.
- You may request to execute multiple sql_db_query tool calls which will be run in parallel.

If you make a mistake, try to recover."""
)


@st.cache_resource
def get_engine(thread_id: str) -> "tuple[str, Engine]":
    # Ephemeral sqlite database per conversation thread
    _, dbpath = tempfile.mkstemp(suffix=".db")
    return dbpath, create_engine(
        f"sqlite:///{dbpath}",
        echo=True,
        isolation_level="AUTOCOMMIT",
    )
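# Cached per thread_id, so each conversation gets its own throwaway sqlite file;
# AUTOCOMMIT makes the agent's writes immediately visible to the database
# contents panel rendered in the second column.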


@st.cache_resource
def get_db(thread_id: str) -> SQLDatabase:
    _, engine = get_engine(thread_id)
    return SQLDatabase(engine)


dbpath, engine = get_engine(st.query_params.thread_id)
db = get_db(st.query_params.thread_id)
toolkit = SQLDatabaseToolkit(db=db, llm=model)

tools = [search, *toolkit.get_tools()]

app = create_react_agent(model, tools, checkpointer=checkpointer, prompt=system_prompt)
config: RunnableConfig = {"configurable": {"thread_id": st.query_params.thread_id}}

if checkpoint := checkpointer.get(config):
    messages: list[BaseMessage] = checkpoint["channel_values"]["messages"]
else:
    messages = []


@st.cache_resource
def get_trace_ids(thread_id: str) -> "dict[str, str]":
    # Stores the trace IDs. Unfortunately I can't find a way to easily retrieve this from the
    # checkpointer, so just store it separately.
    return {}


trace_ids = get_trace_ids(st.query_params.thread_id)

col1, col2 = st.columns([0.6, 0.4])
with col1:
    streamlit_helpers.render_intro()
    st.divider()

    # Add system message
    st.expander(
        "System Instructions", icon=":material/precision_manufacturing:"
    ).markdown(system_prompt.content)

    # Render current messages
    for message in messages:
        trace_id = trace_ids.get(message.id or "")
        streamlit_helpers.render_message(message, trace_id)

    # If user inputs a new prompt, generate and draw a new response
    # TODO: see if st.form() looks better
    file_upload = st.file_uploader(
        "Upload an image",
        type=["png", "jpg", "jpeg", "pdf", "webp"],
        # Hack to clear the upload
        key=f"file_uploader_{st.session_state.upload_key}",
    )
if prompt := st.chat_input():
    content = []

    # Put the image first https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#best-practices
    if file_upload:
        filename: str = file_upload.name
        blob = bucket.blob(filename)
        blob.upload_from_file(file_upload, content_type=file_upload.type)
        st.session_state.upload_key += 1

        uri = f"gs://{bucket.name}/{blob.name}"
        content.append({"type": "image_url", "image_url": {"url": uri}})

    content.append({"type": "text", "text": prompt})

    message = HumanMessage(content)

    with col1:
        with tracer.start_as_current_span(
            "chain invoke",
            attributes={"thread_id": st.query_params.thread_id},
        ) as span:
            trace_id = format_trace_id(span.get_span_context().trace_id)
            streamlit_helpers.render_message(message, trace_id=trace_id)

            # Invoke the agent
            with st.spinner("Thinking..."):
                res = app.invoke({"messages": [message]}, config=config)
            logger.debug("agent response", extra={"response": str(res)})

            # Store trace ID for rendering
            trace_ids[message.id or ""] = trace_id
            trace_ids[res["messages"][-1].id] = trace_id

    st.rerun()

with col2:
with st.expander("See database contents", expanded=True):
streamlit_helpers.render_db_contents(engine, dbpath)

with st.expander("See available tools"):
st.json(tools)

with st.expander("View the message contents in session state"):
st.json(messages)