# Class-Corder / pages / chat.py
# chat.py (copied from the GitHub "Raw" view)
import streamlit as st
from google.cloud import bigquery, storage
import google.generativeai as gai
from PIL import Image
import io
import uuid
import re

from utils import check_login, generate_uuid  

# Gate the page before doing any work: check_login (from utils) is responsible
# for sending users who are not logged in to the authentication page.
check_login(st)

# Initialize the clients
@st.cache_resource
def initialize_clients():
    storage_client = storage.Client()
    bigquery_client = bigquery.Client(st.secrets.project)
    bucket = storage_client.bucket(st.secrets.bucket)
    return storage_client, bigquery_client, bucket

# Shared cloud handles, served from the st.cache_resource-backed factory.
storage_client, bigquery_client, bucket = initialize_clients()

# Gemini setup: a text-only model for plain prompts and a vision model for
# prompts that carry an image.
# NOTE(review): confirm these model names are still served by the API.
gai.configure(api_key=st.secrets.gai_api_key)
text_model = gai.GenerativeModel(model_name='gemini-pro')
image_model = gai.GenerativeModel(model_name='gemini-pro-vision')

# Visibility flag for the sidebar's "new chat" name field; hidden at session start.
if "show_new_chat_form" not in st.session_state:
    st.session_state["show_new_chat_form"] = False

def show_new_chat_form():
    """Button callback: make the new-chat name field visible on the next run."""
    st.session_state["show_new_chat_form"] = True

# Lecture attributes are handed over from the dashboard page via session state.
lecture_id = st.session_state.lecture_id
title = st.session_state.title
summary = st.session_state.summary
transcript = st.session_state.transcript
thumbnail_path = f"thumbnails/{st.session_state.thumbnail_filename}"


@st.cache_data(show_spinner=False)
def _generate_transcript_tags(lecture_transcript):
    """Ask Gemini for topic tags describing *lecture_transcript*; return the raw reply text.

    Cached per transcript. Streamlit re-executes this script on every widget
    interaction; without the cache, each rerun re-sent the whole transcript to
    the model and could get back a different tag set.
    """
    return text_model.generate_content(
        """
       INPUT TRANSCRIPT: """ + lecture_transcript + """
       Generate a set of tags related to the lecture transcript. 
    """
    ).text


# Reduce the tags to lower-cased word tokens with punctuation stripped.
# Plain whitespace splitting kept tokens such as "photosynthesis," that could
# never match the words of a user's prompt.
transcript_tags = set(re.findall(r"\w+", _generate_transcript_tags(transcript).lower()))

if 'current_chat_id' not in st.session_state:
    # Retrieve the newest chat with the highest sequence number if it exists
    query = """
        SELECT chat_id
        FROM `{}.{}.Chats`
        WHERE lecture_id = '{}'
        ORDER BY sequence_number DESC
        LIMIT 1
    """.format(st.secrets.project, st.secrets.dataset, lecture_id)
    job = bigquery_client.query(query)
    result = job.result()

    if result.total_rows > 0: 
        newest_chat_id = list(result)[0]['chat_id']
        st.session_state.current_chat_id = newest_chat_id
    else:
        st.session_state.current_chat_id = ''
else:
    # Retrieve the newest chat with the highest sequence number if it exists
    query = """
        SELECT chat_id
        FROM `{}.{}.Chats`
        WHERE lecture_id = '{}'
        ORDER BY sequence_number DESC
        LIMIT 1
    """.format(st.secrets.project, st.secrets.dataset, lecture_id)
    job = bigquery_client.query(query)
    result = job.result()

    if result.total_rows > 0: 
        newest_chat_id = list(result)[0]['chat_id']
        st.session_state.current_chat_id = newest_chat_id
    else:
        st.session_state.current_chat_id = ''

def insert_chat(lecture_id, chat_name):
    """Create a chat named *chat_name* for *lecture_id* and make it the active chat.

    The chat gets the lecture's next sequence number (highest existing + 1, or
    1 for the first chat), so it sorts first in the newest-first listings.
    Sets ``st.session_state.current_chat_id`` to the new chat's id.
    """
    chat_id = str(generate_uuid())  # Generate a unique chat ID

    # sequence_number is written as a STRING parameter, so the column is
    # presumably STRING. A plain MAX would be lexicographic ("9" > "10") and
    # hand out duplicate numbers after the ninth chat. Cast to compare
    # numerically; this also works if the column is already INT64.
    # NOTE(review): read-then-insert is not atomic; concurrent creates for the
    # same lecture could still pick the same number.
    max_query = """
        SELECT MAX(SAFE_CAST(sequence_number AS INT64)) AS max_sequence
        FROM `{}.{}.Chats`
        WHERE lecture_id = @lecture_id
    """.format(st.secrets.project, st.secrets.dataset)
    max_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ScalarQueryParameter("lecture_id", "STRING", lecture_id)]
    )
    rows = list(bigquery_client.query(max_query, job_config=max_config).result())
    max_sequence = rows[0]['max_sequence'] if rows else None
    sequence_number = int(max_sequence) + 1 if max_sequence is not None else 1

    # Insert the new chat into the Chats table
    insert_query = """
        INSERT INTO `{}.{}.Chats` (lecture_id, chat_id, chat_name, sequence_number)
        VALUES (@lecture_id, @chat_id, @chat_name, @sequence_number)
    """.format(st.secrets.project, st.secrets.dataset)
    insert_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter("lecture_id", "STRING", lecture_id),
            bigquery.ScalarQueryParameter("chat_id", "STRING", chat_id),
            bigquery.ScalarQueryParameter("chat_name", "STRING", chat_name),
            # Declared STRING, so pass a str rather than an int.
            bigquery.ScalarQueryParameter("sequence_number", "STRING", str(sequence_number)),
        ]
    )
    # .result() waits for the INSERT so a following rerun can see the new row.
    bigquery_client.query(insert_query, job_config=insert_config).result()
    st.session_state.current_chat_id = chat_id

def fetch_lecture_chats(lecture_id):
    """Return ``(chat_id, chat_name)`` pairs for *lecture_id*, newest chat first.

    "Newest" means highest sequence_number, compared numerically. The value is
    written as a STRING parameter, and string order would put "9" before "10".
    """
    query = """
        SELECT chat_id, chat_name
        FROM `{}.{}.Chats`
        WHERE lecture_id = @lecture_id
        ORDER BY SAFE_CAST(sequence_number AS INT64) DESC
    """.format(st.secrets.project, st.secrets.dataset)
    # Bind lecture_id as a parameter instead of formatting it into the SQL.
    job_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ScalarQueryParameter("lecture_id", "STRING", lecture_id)]
    )
    rows = bigquery_client.query(query, job_config=job_config).result()
    return [(row['chat_id'], row['chat_name']) for row in rows]

def insert_chat_message(chat_id, user_query, gemini_response, image_url=None):
    """Store one user/Gemini exchange as a row of the ChatMessages table.

    Args:
        chat_id: Chat the exchange belongs to.
        user_query: The user's prompt text.
        gemini_response: The model's reply text.
        image_url: Optional gs:// URL of an image attached to the prompt. The
            image_url column is only written when this value is truthy.
    """
    # Numbered after the messages currently held in session state.
    sequence_number = str(len(st.session_state.chat_history) + 1)

    # Ordered column -> value mapping. Dicts keep insertion order, so the column
    # list, the @placeholders and the query parameters all line up.
    row = {
        "chat_id": chat_id,
        "user_query": user_query,
        "gemini_response": gemini_response,
        "sequence_number": sequence_number,
    }
    if image_url:
        row["image_url"] = image_url

    query = (
        "\n"
        "            INSERT INTO `{}.{}.ChatMessages` \n"
        "            ({})\n"
        "            VALUES ({})\n"
        "        "
    ).format(
        st.secrets.project,
        st.secrets.dataset,
        ", ".join(row),
        ", ".join("@" + column for column in row),
    )
    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter(column, "STRING", value)
            for column, value in row.items()
        ]
    )
    bigquery_client.query(query, job_config=job_config).result()

def fetch_chat_messages(chat_id):
    """Return the messages of *chat_id* in conversation order.

    Each message is a dict with keys ``'user_query'``, ``'gemini_response'``
    and ``'image_url'`` (``None`` when no image was attached).

    Messages are ordered by sequence_number compared numerically. It is written
    as a STRING parameter, and string order would show message 10 before message 2.
    """
    query = """
        SELECT user_query, gemini_response, image_url
        FROM `{}.{}.ChatMessages`
        WHERE chat_id = @chat_id
        ORDER BY SAFE_CAST(sequence_number AS INT64)
    """.format(st.secrets.project, st.secrets.dataset)
    # Bind chat_id as a parameter instead of formatting it into the SQL.
    job_config = bigquery.QueryJobConfig(
        query_parameters=[bigquery.ScalarQueryParameter("chat_id", "STRING", chat_id)]
    )
    rows = bigquery_client.query(query, job_config=job_config).result()

    return [
        {
            'user_query': row['user_query'],
            'gemini_response': row['gemini_response'],
            'image_url': row['image_url'],
        }
        for row in rows
    ]

# Reload the active chat's messages from BigQuery on every run; use an empty
# history while no chat is active.
st.session_state.chat_history = (
    fetch_chat_messages(st.session_state.current_chat_id)
    if st.session_state.current_chat_id
    else []
)

def print_chat_history(messages):
    """Render each exchange as a "user" chat bubble followed by an "ai" bubble.

    When a message has an ``image_url`` (a ``gs://<bucket>/...`` URL), the
    image is downloaded from the app bucket and shown under the user's text.
    """
    for message in messages:
        with st.chat_message("user"):
            st.write(message['user_query'])
            if message['image_url']:
                # Strip the "gs://<bucket>/" prefix to get the object name inside the bucket.
                blob_name = message['image_url'].split(f"gs://{st.secrets.bucket}/")[1]
                raw_image = bucket.blob(blob_name).download_as_string()
                st.image(Image.open(io.BytesIO(raw_image)))

        with st.chat_message("ai"):
            st.write(message['gemini_response'])

def download_chat_history(messages):
    """Offer the chat as a plain-text download, one numbered section per exchange.

    The file is named ``chat_history_<current_chat_id>.txt``.
    """
    sections = []
    for number, message in enumerate(messages, start=1):
        sections.append("Message " + str(number) + ": \n\n")
        sections.append(
            "User: " + message['user_query']
            + "\n\nChatbot: " + message['gemini_response']
            + "\n\n"
        )

    st.download_button(
        label="Download Chat History",
        data="".join(sections),
        file_name="chat_history_" + st.session_state.current_chat_id + ".txt",
        mime="text/plain",
    )

def upload_image_to_gcs(image):
    """Save a PIL *image* as JPEG under ``chat_images/`` in the app bucket.

    The object gets a fresh ``<uuid>.jpeg`` name.

    Returns:
        The ``gs://<bucket>/chat_images/<uuid>.jpeg`` URL of the stored object.
    """
    with io.BytesIO() as buffer:
        image.save(buffer, format='JPEG')
        jpeg_bytes = buffer.getvalue()

    object_name = "chat_images/" + str(generate_uuid()) + ".jpeg"
    bucket.blob(object_name).upload_from_string(jpeg_bytes, content_type='image/jpeg')
    return f"gs://{st.secrets.bucket}/{object_name}"

# Lecture overview: title, summary, thumbnail (read from the app bucket) and the full transcript.
st.title("Chat Page")
st.header(title)
st.write(summary)
thumbnail_bytes = bucket.blob(thumbnail_path).download_as_string()
st.image(Image.open(io.BytesIO(thumbnail_bytes)))
st.write(transcript)

# Sidebar: navigation links, a "new chat" form, and one button per existing chat of this lecture.
with st.sidebar:
    st.page_link("pages/dashboard.py", label="Dashboard")
    st.page_link("pages/quiz.py", label="Quiz Page")
    st.title("Chats")
    # The button only reveals the name field; the chat is created once a name is entered.
    st.button('Create new chat', on_click=show_new_chat_form)
    if st.session_state.show_new_chat_form:
        # Explicit key: the main area can render a text_input with the same
        # label at the same time. Both would otherwise get the same
        # auto-generated widget ID.
        new_chat_name = st.text_input(label="Enter chat name", key="sidebar_new_chat_name")
        if new_chat_name:
            # Create the chat, hide the form, and start from an empty history.
            insert_chat(lecture_id, new_chat_name)
            st.session_state.show_new_chat_form = False
            st.session_state.chat_history = []
            st.rerun()

    lecture_chats = fetch_lecture_chats(lecture_id)
    for chat_id, chat_name in lecture_chats:
        # Key by chat_id: chat names are user-chosen and may repeat, and
        # label-only buttons with equal names would collide.
        if st.button(chat_name, key=f"chat_button_{chat_id}"):
            st.session_state.current_chat_id = chat_id
            st.session_state.chat_history = fetch_chat_messages(chat_id)

st.subheader('Ask further questions and enhance your knowledge ')

# Earlier exchanges of the active chat are drawn in a dedicated container
# placed here, under the subheader and above the prompt widgets.
chat_history_container = st.container()
history = st.session_state.chat_history
if history:
    with chat_history_container:
        print_chat_history(history)

@st.cache_data(show_spinner=False)
def _suggest_prompts(lecture_transcript):
    """Return Gemini's newline-separated list of three follow-up questions for a transcript.

    Cached per transcript. Streamlit re-executes this script on every widget
    interaction (each chat message, each sidebar click). Without the cache,
    every rerun re-sent the whole transcript to the model.
    """
    return text_model.generate_content(
        """
        INPUT TRANSCRIPT: """ + lecture_transcript + """
        Generate three possible questions that someone might ask to expand their knowledge on the topic of the given lecture transcipt. The sample questions must be short and relate to the content of the lecture transcript, and must be separated by new lines.
        """
    ).text


def _topic_words(text):
    """Return the lower-cased word tokens of *text*, with punctuation stripped.

    Used by the on-topic checks, so a prompt word like "photosynthesis?" is
    compared as "photosynthesis".
    """
    return set(re.findall(r"\w+", text.lower()))


# No chat exists yet for this lecture (e.g. first visit from the dashboard):
# ask for a name and create the chat before showing the prompt UI.
if not st.session_state.current_chat_id:
    placeholder = st.empty()
    with placeholder.container():
        # Explicit key: the sidebar can show a text_input with the same label at the same time.
        chat_name = st.text_input(label="Enter chat name", key="first_chat_name")
        if chat_name:
            # Clear the placeholder after input is provided
            placeholder.empty()
            insert_chat(lecture_id, chat_name)
            st.rerun()
# At least one chat exists for this lecture: show the prompt UI.
else:
    prompt_suggestions = _suggest_prompts(transcript)
    st.write("Prompt Suggestions: \n" + prompt_suggestions)

    if st.session_state.chat_history:
        download_chat_history(st.session_state.chat_history)

    # Optional image for the next prompt. Uploads are filtered by extension, so
    # list both common JPEG extensions.
    uploaded_image = st.file_uploader(label="Upload image (Optional)", type=['jpg', 'jpeg'])
    image = None
    if uploaded_image:
        image = Image.open(uploaded_image)
        # Reject images whose generated tags share no word with the lecture's tags.
        image_tags = image_model.generate_content(
            contents=["Generate a set of tags related to the input image.", image]
        ).text
        if not _topic_words(image_tags).intersection(transcript_tags):
            image = None
            st.warning("Please enter an image that relates to the topic of the lecture.")

    prompt = st.chat_input(placeholder="Enter prompt")
    if prompt:
        # Same word-overlap check for the prompt text itself.
        if not _topic_words(prompt).intersection(transcript_tags):
            st.warning("Please enter a prompt that relates to the topic of the lecture.")
        else:
            if image:
                response = image_model.generate_content(contents=[
                    """
                    You're an expert on the topic of the following lecture transcript. 
                    LECTURE TRANSCRIPT: """ + transcript + """
                    Generate a response to the following prompt and input image which may encompass broader aspects of the topic of the lecture.
                    PROMPT: """ + prompt + """
                    """, image]).text
                image_url = upload_image_to_gcs(image)
                insert_chat_message(st.session_state.current_chat_id, prompt, response, image_url=image_url)
                # Rerun so the reloaded history shows the new exchange.
                st.rerun()

            else:
                response = text_model.generate_content(
                    """
                    You're an expert on the topic of the following lecture transcript. 
                    LECTURE TRANSCRIPT: """ + transcript + """
                    Generate a response to the following prompt which may encompass broader aspects of the topic of the lecture.
                    PROMPT: """ + prompt + """
                    """).text
                insert_chat_message(st.session_state.current_chat_id, prompt, response)
                # Rerun so the reloaded history shows the new exchange.
                st.rerun()