"""Chat page: ask Gemini about the lecture selected on the dashboard.

Chats and chat messages are stored in BigQuery; images attached to messages
and lecture thumbnails live in the app's Cloud Storage bucket.
"""
import io
import re
import uuid

import google.generativeai as gai
import streamlit as st
from google.cloud import bigquery, storage
from PIL import Image

from utils import check_login, generate_uuid

# Redirect to the authentication page if the user is not logged in.
check_login(st)


@st.cache_resource
def initialize_clients():
    """Create the Cloud Storage / BigQuery clients and the app bucket handle once."""
    storage_client = storage.Client()
    bigquery_client = bigquery.Client(st.secrets.project)
    bucket = storage_client.bucket(st.secrets.bucket)
    return storage_client, bigquery_client, bucket


# Load clients
storage_client, bigquery_client, bucket = initialize_clients()

# Initialize Gemini models
gai.configure(api_key=st.secrets.gai_api_key)
text_model = gai.GenerativeModel('gemini-pro')
image_model = gai.GenerativeModel('gemini-pro-vision')

# Fully-qualified table names. Project/dataset come from app secrets (trusted),
# so formatting them into SQL is safe; user-derived values are always passed
# as query parameters instead.
CHATS_TABLE = f"`{st.secrets.project}.{st.secrets.dataset}.Chats`"
CHAT_MESSAGES_TABLE = f"`{st.secrets.project}.{st.secrets.dataset}.ChatMessages`"

if 'show_new_chat_form' not in st.session_state:
    st.session_state.show_new_chat_form = False


def show_new_chat_form():
    """Button callback: reveal the 'new chat' name input in the sidebar."""
    st.session_state.show_new_chat_form = True


# The lecture attributes are handed over by the dashboard through the session
# state; stop cleanly if the user opened this page directly.
if 'lecture_id' not in st.session_state:
    st.warning("Please select a lecture from the dashboard first.")
    st.page_link("pages/dashboard.py", label="Dashboard")
    st.stop()

# Retrieve lecture attributes from the dashboard through the session state
lecture_id = st.session_state.lecture_id
title = st.session_state.title
summary = st.session_state.summary
transcript = st.session_state.transcript
thumbnail_path = f"thumbnails/{st.session_state.thumbnail_filename}"


# --------------------------------------------------------------------------
# Gemini helpers
# --------------------------------------------------------------------------

def _response_text(response):
    """Return the text of a Gemini response, or '' when it has no valid part.

    `response.text` raises ValueError when the candidate was blocked (e.g. by
    safety filters), which would otherwise crash the whole page.
    """
    try:
        return response.text
    except ValueError:
        return ""


def _tokenize(text):
    """Lower-case `text` and return the set of its word tokens.

    Punctuation and list markers ('-', '*', ',', '#') are dropped so that a
    tag such as 'Learning,' still matches the word 'learning' in a prompt.
    """
    return set(re.findall(r"\w+", text.lower()))


@st.cache_data(show_spinner=False)
def generate_transcript_tags(transcript):
    """Ask Gemini for topic tags of `transcript` and return them as a token set.

    Cached per transcript so the model is not re-queried on every rerun.
    """
    response = text_model.generate_content(
        "INPUT TRANSCRIPT:\n"
        f"{transcript}\n\n"
        "Generate a set of tags related to the lecture transcript.\n"
    )
    return _tokenize(_response_text(response))


@st.cache_data(show_spinner=False)
def generate_prompt_suggestions(transcript):
    """Ask Gemini for three short follow-up questions about `transcript`.

    Cached per transcript so the suggestions stay stable across reruns and the
    model is not re-queried on every interaction.
    """
    response = text_model.generate_content(
        "INPUT TRANSCRIPT:\n"
        f"{transcript}\n\n"
        "Generate three possible questions that someone might ask to expand "
        "their knowledge on the topic of the given lecture transcript. The "
        "sample questions must be short and relate to the content of the "
        "lecture transcript, and must be separated by new lines.\n"
    )
    return _response_text(response)


@st.cache_data(show_spinner=False)
def generate_image_tags(image_bytes):
    """Ask Gemini for tags describing the image in `image_bytes` (token set).

    Cached on the raw bytes so the same upload is not re-tagged on every rerun.
    """
    image = Image.open(io.BytesIO(image_bytes))
    response = image_model.generate_content(
        contents=["Generate a set of tags related to the input image.", image]
    )
    return _tokenize(_response_text(response))


def _build_answer_prompt(transcript, prompt, with_image):
    """Build the instruction sent to Gemini to answer `prompt` about the lecture."""
    subject = (
        "the following prompt and input image" if with_image else "the following prompt"
    )
    return (
        "You're an expert on the topic of the following lecture transcript.\n\n"
        f"LECTURE TRANSCRIPT:\n{transcript}\n\n"
        f"Generate a response to {subject} which may encompass broader "
        "aspects of the topic of the lecture.\n\n"
        f"PROMPT:\n{prompt}\n"
    )


transcript_tags = generate_transcript_tags(transcript)


# --------------------------------------------------------------------------
# BigQuery helpers
# --------------------------------------------------------------------------

def _string_param(name, value):
    """Build a STRING scalar query parameter (value converted with str())."""
    return bigquery.ScalarQueryParameter(name, "STRING", str(value))


def _run_query(sql, params=()):
    """Run `sql` with the given query parameters, wait, and return the rows."""
    job_config = bigquery.QueryJobConfig(query_parameters=list(params))
    return bigquery_client.query(sql, job_config=job_config).result()


def fetch_newest_chat_id(lecture_id):
    """Return the id of the lecture's chat with the highest sequence number,
    or '' when the lecture has no chats yet."""
    # sequence_number is written as a STRING, so compare it numerically;
    # lexicographic order would rank '9' above '10'.
    rows = _run_query(
        f"""
        SELECT chat_id
        FROM {CHATS_TABLE}
        WHERE lecture_id = @lecture_id
        ORDER BY CAST(sequence_number AS INT64) DESC
        LIMIT 1
        """,
        [_string_param("lecture_id", lecture_id)],
    )
    row = next(iter(rows), None)
    return row['chat_id'] if row is not None else ''


# Select the newest chat only when entering the page for a (different)
# lecture. Re-selecting it on every rerun would discard the chat the user
# picked in the sidebar, so their next prompt would land in the wrong chat.
if (
    'current_chat_id' not in st.session_state
    or st.session_state.get('current_chat_lecture_id') != lecture_id
):
    st.session_state.current_chat_id = fetch_newest_chat_id(lecture_id)
    st.session_state.current_chat_lecture_id = lecture_id


def insert_chat(lecture_id, chat_name):
    """Create a chat named `chat_name` for `lecture_id` and make it current.

    The new chat gets sequence number (largest existing one + 1), or 1 when
    the lecture has no chats yet.
    """
    chat_id = str(generate_uuid())  # Generate a unique chat ID

    rows = _run_query(
        f"""
        SELECT MAX(CAST(sequence_number AS INT64)) AS max_sequence
        FROM {CHATS_TABLE}
        WHERE lecture_id = @lecture_id
        """,
        [_string_param("lecture_id", lecture_id)],
    )
    row = next(iter(rows), None)
    max_sequence = row['max_sequence'] if row is not None else None
    sequence_number = 1 if max_sequence is None else int(max_sequence) + 1

    _run_query(
        f"""
        INSERT INTO {CHATS_TABLE} (lecture_id, chat_id, chat_name, sequence_number)
        VALUES (@lecture_id, @chat_id, @chat_name, @sequence_number)
        """,
        [
            _string_param("lecture_id", lecture_id),
            _string_param("chat_id", chat_id),
            _string_param("chat_name", chat_name),
            _string_param("sequence_number", sequence_number),
        ],
    )
    st.session_state.current_chat_id = chat_id
    st.session_state.current_chat_lecture_id = lecture_id


def fetch_lecture_chats(lecture_id):
    """Return (chat_id, chat_name) pairs for the lecture, newest chat first."""
    rows = _run_query(
        f"""
        SELECT chat_id, chat_name
        FROM {CHATS_TABLE}
        WHERE lecture_id = @lecture_id
        ORDER BY CAST(sequence_number AS INT64) DESC
        """,
        [_string_param("lecture_id", lecture_id)],
    )
    return [(row['chat_id'], row['chat_name']) for row in rows]


def insert_chat_message(chat_id, user_query, gemini_response, image_url=None):
    """Store one user prompt / Gemini answer pair in `chat_id`.

    The sequence number is one past the number of messages currently held in
    st.session_state.chat_history (which this page reloads for the current
    chat on every run). `image_url` is only written when truthy.
    """
    columns = ["chat_id", "user_query", "gemini_response", "sequence_number"]
    params = [
        _string_param("chat_id", chat_id),
        _string_param("user_query", user_query),
        _string_param("gemini_response", gemini_response),
        _string_param("sequence_number", len(st.session_state.chat_history) + 1),
    ]
    if image_url:
        columns.append("image_url")
        params.append(_string_param("image_url", image_url))

    placeholders = ", ".join(f"@{column}" for column in columns)
    _run_query(
        f"INSERT INTO {CHAT_MESSAGES_TABLE} ({', '.join(columns)}) "
        f"VALUES ({placeholders})",
        params,
    )


def fetch_chat_messages(chat_id):
    """Return the chat's messages in sequence order as a list of dicts with
    keys 'user_query', 'gemini_response' and 'image_url'."""
    rows = _run_query(
        f"""
        SELECT user_query, gemini_response, image_url
        FROM {CHAT_MESSAGES_TABLE}
        WHERE chat_id = @chat_id
        ORDER BY CAST(sequence_number AS INT64)
        """,
        [_string_param("chat_id", chat_id)],
    )
    return [
        {
            'user_query': row['user_query'],
            'gemini_response': row['gemini_response'],
            'image_url': row['image_url'],
        }
        for row in rows
    ]


# Reload the current chat's messages on every run so the page shows what is stored.
st.session_state.chat_history = (
    fetch_chat_messages(st.session_state.current_chat_id)
    if st.session_state.current_chat_id
    else []
)


# --------------------------------------------------------------------------
# Rendering / storage helpers
# --------------------------------------------------------------------------

def _load_bucket_image(blob_name):
    """Download `blob_name` from the app bucket and open it as a PIL image."""
    return Image.open(io.BytesIO(bucket.blob(blob_name).download_as_bytes()))


def print_chat_history(messages):
    """Render each stored exchange as a user bubble (with its image, if any)
    followed by an AI bubble."""
    bucket_prefix = f"gs://{st.secrets.bucket}/"
    for message in messages:
        with st.chat_message("user"):
            st.write(message['user_query'])
            image_url = message['image_url']
            if image_url:
                # Stored as gs://<bucket>/<blob>; the blob API wants <blob>.
                blob_name = (
                    image_url[len(bucket_prefix):]
                    if image_url.startswith(bucket_prefix)
                    else image_url
                )
                st.image(_load_bucket_image(blob_name))
        with st.chat_message("ai"):
            st.write(message['gemini_response'])


def download_chat_history(messages):
    """Offer the chat transcript as a plain-text download button."""
    chat_history_text = "".join(
        f"Message {number}: \n\n"
        f"User: {message['user_query']}\n\n"
        f"Chatbot: {message['gemini_response']}\n\n"
        for number, message in enumerate(messages, start=1)
    )
    st.download_button(
        label="Download Chat History",
        data=chat_history_text,
        file_name="chat_history_" + st.session_state.current_chat_id + ".txt",
        mime="text/plain",
    )


def upload_image_to_gcs(image):
    """Save `image` as JPEG under chat_images/ in the app bucket with a unique
    name, and return its gs:// URL."""
    with io.BytesIO() as output:
        image.save(output, format='JPEG')
        image_bytes = output.getvalue()

    image_file_name = str(generate_uuid()) + '.jpeg'
    blob = bucket.blob(f"chat_images/{image_file_name}")
    blob.upload_from_string(image_bytes, content_type='image/jpeg')
    return f"gs://{st.secrets.bucket}/chat_images/{image_file_name}"


# --------------------------------------------------------------------------
# Page layout
# --------------------------------------------------------------------------

# Display lecture attributes to the user
st.title("Chat Page")
st.header(title)
st.write(summary)
st.image(_load_bucket_image(thumbnail_path))
st.write(transcript)

# Links to the dashboard and quiz page, plus the lecture's chats, in the sidebar
with st.sidebar:
    st.page_link("pages/dashboard.py", label="Dashboard")
    st.page_link("pages/quiz.py", label="Quiz Page")
    st.title("Chats")

    st.button('Create new chat', on_click=show_new_chat_form)
    if st.session_state.show_new_chat_form:
        # Explicit key: the main area may show an identically-labelled input
        # at the same time, which would otherwise raise DuplicateWidgetID.
        new_chat_name = st.text_input(
            label="Enter chat name", key="sidebar_new_chat_name"
        )
        if new_chat_name:
            insert_chat(lecture_id, new_chat_name)
            st.session_state.show_new_chat_form = False
            st.session_state.chat_history = []
            st.rerun()

    for chat_id, chat_name in fetch_lecture_chats(lecture_id):
        # Key by chat id: several chats may share a name.
        if st.button(chat_name, key=f"chat_button_{chat_id}"):
            st.session_state.current_chat_id = chat_id
            st.session_state.chat_history = fetch_chat_messages(chat_id)

st.subheader('Ask further questions and enhance your knowledge ')

# Display chat history above the chat prompt
chat_history_container = st.container()
if st.session_state.chat_history:
    with chat_history_container:
        print_chat_history(st.session_state.chat_history)

if not st.session_state.current_chat_id:
    # No chat exists yet for this lecture: ask for a name for the first one.
    placeholder = st.empty()
    with placeholder.container():
        chat_name = st.text_input(label="Enter chat name", key="first_chat_name")
    if chat_name:
        placeholder.empty()
        insert_chat(lecture_id, chat_name)
        st.rerun()
else:
    # At least one chat exists for this lecture.
    st.write("Prompt Suggestions: \n" + generate_prompt_suggestions(transcript))

    if st.session_state.chat_history:
        download_chat_history(st.session_state.chat_history)

    # Optional image from the user; only kept if it relates to the lecture.
    uploaded_image = st.file_uploader(
        label="Upload image (Optional)", type=['jpeg', 'jpg']
    )
    image = None
    if uploaded_image is not None:
        image_bytes = uploaded_image.getvalue()
        image = Image.open(io.BytesIO(image_bytes))
        if generate_image_tags(image_bytes).isdisjoint(transcript_tags):
            image = None
            st.warning("Please enter an image that relates to the topic of the lecture.")

    prompt = st.chat_input(placeholder="Enter prompt")
    if prompt:
        # Relevance check: the prompt must share a word with the lecture tags.
        if _tokenize(prompt).isdisjoint(transcript_tags):
            st.warning("Please enter a prompt that relates to the topic of the lecture.")
        else:
            with_image = image is not None
            instruction = _build_answer_prompt(transcript, prompt, with_image)
            if with_image:
                response = _response_text(
                    image_model.generate_content(contents=[instruction, image])
                )
            else:
                response = _response_text(text_model.generate_content(instruction))

            if not response:
                st.error(
                    "Gemini did not return a response. "
                    "Please rephrase your prompt and try again."
                )
            else:
                image_url = upload_image_to_gcs(image) if with_image else None
                insert_chat_message(
                    st.session_state.current_chat_id,
                    prompt,
                    response,
                    image_url=image_url,
                )
                st.rerun()