|
| 1 | +import time |
| 2 | +import os |
| 3 | +import joblib |
| 4 | +import streamlit as st |
| 5 | +import google.generativeai as genai |
| 6 | +from dotenv import load_dotenv |
| 7 | + |
| 8 | +# Load environment variables |
| 9 | +load_dotenv() |
| 10 | +GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY') |
| 11 | +genai.configure(api_key=os.environ.get('GEMINI_API_KEY')) |
| 12 | + |
| 13 | +# Constants |
| 14 | +MODEL_ROLE = 'ai' |
| 15 | +AI_AVATAR_ICON = '👄' |
| 16 | +DATA_DIR = 'data/' |
| 17 | + |
| 18 | + |
| 19 | +def history_chatbot(): |
| 20 | + # Ensure the data/ directory exists |
| 21 | + os.makedirs(DATA_DIR, exist_ok=True) |
| 22 | + |
| 23 | + # Generate a new chat ID |
| 24 | + new_chat_id = f'{time.time()}' |
| 25 | + |
| 26 | + # Load past chats if available |
| 27 | + try: |
| 28 | + past_chats = joblib.load(os.path.join(DATA_DIR, 'past_chats_list')) |
| 29 | + except FileNotFoundError: |
| 30 | + past_chats = {} |
| 31 | + |
| 32 | + # Sidebar for past chats |
| 33 | + with st.sidebar: |
| 34 | + st.write('# Past Chats') |
| 35 | + if 'chat_id' not in st.session_state: |
| 36 | + st.session_state.chat_id = st.selectbox( |
| 37 | + label='Pick a past chat', |
| 38 | + options=[new_chat_id] + list(past_chats.keys()), |
| 39 | + format_func=lambda x: past_chats.get(x, 'New Chat'), |
| 40 | + placeholder='_' |
| 41 | + ) |
| 42 | + else: |
| 43 | + st.session_state.chat_id = st.selectbox( |
| 44 | + label='Pick a past chat', |
| 45 | + options=[new_chat_id, st.session_state.chat_id] + list(past_chats.keys()), |
| 46 | + index=1, |
| 47 | + format_func=lambda x: past_chats.get(x, 'New Chat' if x != st.session_state.chat_id else st.session_state.chat_title), |
| 48 | + placeholder='_' |
| 49 | + ) |
| 50 | + st.session_state.chat_title = f'ChatSession-{st.session_state.chat_id}' |
| 51 | + |
| 52 | + # Load chat history if available |
| 53 | + try: |
| 54 | + st.session_state.messages = joblib.load(os.path.join(DATA_DIR, f'{st.session_state.chat_id}-st_messages')) |
| 55 | + st.session_state.gemini_history = joblib.load(os.path.join(DATA_DIR, f'{st.session_state.chat_id}-gemini_messages')) |
| 56 | + print('Loaded existing chat history') |
| 57 | + except FileNotFoundError: |
| 58 | + st.session_state.messages = [] |
| 59 | + st.session_state.gemini_history = [] |
| 60 | + print('Initialized new chat history') |
| 61 | + |
| 62 | + # Configure the AI model |
| 63 | + st.session_state.model = genai.GenerativeModel('gemini-pro') |
| 64 | + st.session_state.chat = st.session_state.model.start_chat(history=st.session_state.gemini_history) |
| 65 | + |
| 66 | + # Display past messages |
| 67 | + for message in st.session_state.messages: |
| 68 | + with st.chat_message(name=message['role'], avatar=message.get('avatar')): |
| 69 | + st.markdown(message['content']) |
| 70 | + |
| 71 | + # Handle user input |
| 72 | + if prompt := st.chat_input('Ask Alwrity...'): |
| 73 | + if st.session_state.chat_id not in past_chats: |
| 74 | + past_chats[st.session_state.chat_id] = st.session_state.chat_title |
| 75 | + joblib.dump(past_chats, os.path.join(DATA_DIR, 'past_chats_list')) |
| 76 | + |
| 77 | + # Display and save user message |
| 78 | + with st.chat_message('user'): |
| 79 | + st.markdown(prompt) |
| 80 | + st.session_state.messages.append({'role': 'user', 'content': prompt}) |
| 81 | + |
| 82 | + # Send message to AI and stream the response |
| 83 | + response = st.session_state.chat.send_message(prompt, stream=True) |
| 84 | + full_response = '' |
| 85 | + with st.chat_message(name=MODEL_ROLE, avatar=AI_AVATAR_ICON): |
| 86 | + message_placeholder = st.empty() |
| 87 | + for chunk in response: |
| 88 | + for ch in chunk.text.split(' '): |
| 89 | + full_response += ch + ' ' |
| 90 | + time.sleep(0.05) |
| 91 | + message_placeholder.write(full_response + '▌') |
| 92 | + message_placeholder.write(full_response) |
| 93 | + |
| 94 | + # Save the AI response |
| 95 | + st.session_state.messages.append({ |
| 96 | + 'role': MODEL_ROLE, |
| 97 | + 'content': full_response, |
| 98 | + 'avatar': AI_AVATAR_ICON |
| 99 | + }) |
| 100 | + st.session_state.gemini_history = st.session_state.chat.history |
| 101 | + |
| 102 | + # Persist chat history to disk |
| 103 | + joblib.dump(st.session_state.messages, os.path.join(DATA_DIR, f'{st.session_state.chat_id}-st_messages')) |
| 104 | + joblib.dump(st.session_state.gemini_history, os.path.join(DATA_DIR, f'{st.session_state.chat_id}-gemini_messages')) |
0 commit comments