# Class-Corder / pages / dashboard.py
import datetime
import json
import os
import textwrap
import uuid

import google.generativeai as gai
import streamlit as st
import vertexai
from google.cloud import speech, bigquery, storage
from st_audiorec import st_audiorec
from vertexai.preview.vision_models import ImageGenerationModel

from utils import check_login, generate_uuid, json_validator

# Redirect to the authentication page if the user is not logged in.
check_login(st)

# Google Cloud / Gemini clients shared by the helper functions on this page.
storage_client = storage.Client()
bigquery_client = bigquery.Client(st.secrets.project)
bucket = storage_client.bucket(st.secrets.bucket)
gai.configure(api_key=st.secrets.gai_api_key)

# Default value for every session-state key this page relies on. A key that
# already holds a value from an earlier rerun keeps that value.
_SESSION_DEFAULTS = {
    "audio_file": None,
    "audio_file_name": '',
    "audio_language": 'en-US',
    "transcript": '',
    "title": '',
    "summary": '',
    "imagen_prompt": '',
    "created_at": '',
    "thumbnail_filename": None,
    "lectures": [],
    "updateSidebar": False,
    "lecture_id": '',
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    st.session_state[_state_key] = st.session_state.get(_state_key, _state_default)

# Selectbox label -> language code handed to Speech-to-Text (language_code).
language_map = {
    "English": "en-US",
    "Spanish": "es-ES",
    "French": "fr-FR"
}

# Uploaded file extension -> Speech-to-Text encoding.
# NOTE(review): .m4a files are normally AAC inside an MP4 container, not MP3;
# confirm the Speech API actually accepts these uploads.
audio_format_map = {
    "wav": speech.RecognitionConfig.AudioEncoding.LINEAR16,
    "mp3": speech.RecognitionConfig.AudioEncoding.MP3,
    "m4a": speech.RecognitionConfig.AudioEncoding.MP3
}

def upload_file(file, file_path):
    """Store ``file`` (raw data) at ``file_path`` in the configured bucket.

    Returns the created ``Blob`` so callers can inspect its bucket/name.
    """
    target = bucket.blob(file_path)
    # Lecture recordings can be large; allow the upload up to 86400 s (24 h).
    target.upload_from_string(data=file, timeout=86400)
    return target

def generate_text(prompt):
    """Send ``prompt`` to the 'gemini-pro' model and return the reply text."""
    response = gai.GenerativeModel('gemini-pro').generate_content(prompt)
    return response.text

# Transcribe the given audio file.
def transcribe_file():
    """Upload the current audio to GCS and transcribe it with Speech-to-Text.

    Reads ``st.session_state.audio_file``. This is either an ``UploadedFile``
    from ``st.file_uploader`` or the value returned by ``st_audiorec``, which is
    assumed to be raw WAV bytes (TODO confirm). The audio is uploaded to
    ``audio_files/<lecture_id>.<ext>`` and sent through a long-running
    recognition job.

    Side effects: sets ``st.session_state.audio_file_name`` and
    ``st.session_state.transcript``. The transcript is one line per
    recognition result, each terminated by a newline.
    """
    speech_client = speech.SpeechClient()

    audio_file = st.session_state.audio_file

    if isinstance(audio_file, (bytes, bytearray)):
        # Microphone recordings carry no file name, so there is no
        # ``.name``/``.getvalue()`` to call; treat them as WAV.
        audio_bytes = bytes(audio_file)
        extension = "wav"
    else:
        audio_bytes = audio_file.getvalue()
        # Lower-case so "LECTURE.MP3" still matches audio_format_map's keys.
        extension = audio_file.name.rsplit(".", 1)[-1].lower()

    st.session_state.audio_file_name = f"{st.session_state.lecture_id}.{extension}"

    blob = upload_file(audio_bytes, f"audio_files/{st.session_state.audio_file_name}")

    # Build the URI from bucket + object name instead of trimming the
    # generation suffix off ``blob.id``.
    gcs_uri = f"gs://{blob.bucket.name}/{blob.name}"

    config_kwargs = {
        "encoding": audio_format_map[extension],
        "language_code": st.session_state.audio_language,
    }
    if extension != "wav":
        # WAV headers declare their own sample rate, and the API rejects an
        # explicit value that does not match it. Only pin the rate for other
        # formats, where it is required.
        # NOTE(review): 16000 Hz is still a guess for MP3 uploads.
        config_kwargs["sample_rate_hertz"] = 16000
    # NOTE(review): browser recordings may be stereo; Speech-to-Text may need
    # audio_channel_count set for them. Confirm with a real recording.
    config = speech.RecognitionConfig(**config_kwargs)
    audio = speech.RecognitionAudio(uri=gcs_uri)

    operation = speech_client.long_running_recognize(config=config, audio=audio)
    response = operation.result(timeout=3600)

    # Skip results that came back without any alternative rather than
    # crashing on ``alternatives[0]``.
    lines = [
        result.alternatives[0].transcript
        for result in response.results
        if result.alternatives
    ]
    st.session_state.transcript = "".join(f"{line}\n" for line in lines)

def _strip_code_fences(text):
    """Return ``text`` without a surrounding Markdown code fence.

    LLMs often wrap JSON in a ```json ... ``` fence, which makes ``json.loads``
    fail. Text without a leading fence is only stripped of outer whitespace.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the opening fence line, which may carry a language tag ("```json").
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


# Generate metadata and an Imagen prompt for the lecture transcript.
def generate_metadata(attempt=1, max_attempts=3):
    """Ask the LLM for a title, summary, and image prompt for the transcript.

    Reads ``st.session_state.transcript``. On success it writes
    ``st.session_state.title``, ``.summary``, and ``.imagen_prompt``.

    Args:
        attempt: Number of the first attempt. Kept for backward compatibility;
            it was used for recursive retries.
        max_attempts: Number of the last attempt. At least one attempt is
            always made.

    Returns:
        True if the metadata was stored. False if every attempt failed; an
        error is shown in the UI and the session values are left unchanged.
    """
    prompt_template = textwrap.dedent(
        """
        Your goal is to generate a title, summary, and prompt for an image generator based on a given lecture transcript. The title and summary must describe the contents of the lecture while being concise at the same time. The summary must be up to 3 sentences long. The prompt for the image generator must be descriptive and focus on the main subject of the lecture. The output must be outputted in the form of a JSON object as such:

        {
            "title": "Lecture title",
            "summary": "Lecture summary.",
            "prompt": "Prompt for image generator"
        }

        Do not include single quotes or curly quotes in the output. The output must be in the exact format as shown above.

        INPUT: {transcript}

        OUTPUT JSON:
        """)
    # Substitute after dedenting. Inserting the unindented transcript first
    # leaves dedent no common indent to remove.
    final_prompt = prompt_template.replace("{transcript}", st.session_state.transcript)

    for _ in range(max(1, max_attempts - attempt + 1)):
        try:
            output_json = _strip_code_fences(generate_text(final_prompt))
            output_dict = json.loads(output_json)
            # json_validator is a project helper, assumed to check that the
            # listed keys are present (TODO confirm). The isinstance check
            # guards the dict indexing below.
            if not isinstance(output_dict, dict) or not json_validator(
                output_json, ["title", "summary", "prompt"]
            ):
                raise ValueError("Model output is missing required metadata keys.")
        except ValueError:
            # Covers json.JSONDecodeError (a ValueError subclass), failed
            # validation, and any ValueError raised while reading the model
            # output. Try again.
            continue

        st.session_state.title = output_dict["title"]
        st.session_state.summary = output_dict["summary"]
        st.session_state.imagen_prompt = output_dict["prompt"]
        return True

    st.error("Failed to generate metadata. Please try again.")
    return False
    
def clean_up_transcript():
    """Ask the LLM to fix grammar and spelling in the stored transcript.

    Replaces ``st.session_state.transcript`` with the model's cleaned-up
    version. When the transcript is empty or whitespace-only there is nothing
    to clean, so no LLM call is made and the state is left untouched.
    """
    transcript = st.session_state.transcript
    if not transcript or not transcript.strip():
        return

    prompt_template = textwrap.dedent(
        """
        Your goal is to clean up a given lecture transcript. The transcript may contain grammatical errors or spelling mistakes. 
        The output must be a cleaned-up version of the transcript with proper sentences and capitalization, 
        but it must still maintain the original structure and content of the transcript.

        INPUT: {transcript}

        OUTPUT TRANSCRIPT:
        """)
    # Substitute after dedenting. Inserting the unindented transcript first
    # leaves dedent no common indent to remove.
    final_prompt = prompt_template.replace("{transcript}", transcript)

    st.session_state.transcript = generate_text(final_prompt)
    
def generate_thumbnail():
    """Generate a lecture thumbnail with Imagen and upload it to the bucket.

    Uses ``st.session_state.imagen_prompt``. On success it saves
    ``<lecture_id>.png`` in the working directory, uploads it to
    ``thumbnails/<lecture_id>.png``, and records the local file name in
    ``st.session_state.thumbnail_filename``. Returns early, leaving state
    untouched, when the prompt is empty or no image comes back.
    """
    prompt = st.session_state.imagen_prompt
    if not prompt:
        # Metadata generation failed; don't send an empty prompt to Imagen.
        return

    vertexai.init(project=st.secrets.project, location=st.secrets.location)
    model = ImageGenerationModel.from_pretrained("imagegeneration@005")
    # Materialize the results. The truthiness of the SDK's response object
    # does not reliably say whether it contains any images.
    images = list(model.generate_images(prompt))
    if not images:
        return

    file_name = f"{st.session_state.lecture_id}.png"
    st.session_state.thumbnail_filename = file_name
    images[0].save(file_name)
    # Close the handle deterministically instead of leaking it.
    with open(file_name, "rb") as image_file:
        upload_file(image_file.read(), f"thumbnails/{file_name}")

def upload_to_database():
    """Insert the current lecture as a new row of the BigQuery Lectures table.

    Each column value is bound as a positional STRING query parameter. After
    the insert job finishes, the sidebar is flagged to reload its list.
    """
    bucket_root = f"gs://{st.secrets.bucket}"

    insert_sql = """
    INSERT INTO `{project}.{dataset}.Lectures` 
    (Lecture_id, Username, Title, Summary, Transcript, Created_at, Audio_url, Thumbnail_url) 
    VALUES 
    (?, ?, ?, ?, ?, ?, ?, ?)
    """.format(project=st.secrets.project, dataset=st.secrets.dataset)

    # Order must match the column list in insert_sql.
    column_values = [
        st.session_state.lecture_id,
        st.session_state.username,
        st.session_state.title,
        st.session_state.summary,
        st.session_state.transcript,
        st.session_state.created_at,
        f"{bucket_root}/audio_files/{st.session_state.audio_file_name}",
        f"{bucket_root}/thumbnails/{st.session_state.lecture_id}.png",
    ]

    job_config = bigquery.QueryJobConfig()
    job_config.query_parameters = [
        bigquery.ScalarQueryParameter(None, 'STRING', column_value)
        for column_value in column_values
    ]
    # Block until the INSERT has completed before flagging the refresh.
    bigquery_client.query(insert_sql, job_config=job_config).result()
    st.session_state.updateSidebar = True
    
# load
def load_lecture_from_row(row):
    """Make ``row`` (a Lectures query result) the currently displayed lecture.

    Copies its metadata into session state, downloads the thumbnail into the
    working directory, and reads the audio recording into memory as bytes.
    Used as the sidebar "Select" button callback.
    """
    state = st.session_state
    state.lecture_id = row.Lecture_id
    state.title = row.Title
    state.summary = row.Summary
    state.transcript = row.Transcript
    state.created_at = row.Created_at

    bucket_prefix = f"gs://{st.secrets.bucket}/"

    def object_path(url):
        # The object name is the text following "gs://<bucket>/".
        return url.split(bucket_prefix)[1]

    thumbnail_path = object_path(row.Thumbnail_url)
    thumbnail_name = thumbnail_path.split("/")[-1]
    bucket.blob(thumbnail_path).download_to_filename(thumbnail_name)

    state.audio_file = bucket.blob(object_path(row.Audio_url)).download_as_bytes()

    state.thumbnail_filename = thumbnail_name

st.title("ClassCorder")
st.subheader("Record, Transcribe, and Understand Your Lectures")

st.divider()

# --- Record / upload a new lecture and run the processing pipeline. ---
st.header("Record Lecture")
audio_source = st.radio("Choose audio input:", ("Upload from file", "Record from microphone"))
if audio_source == "Record from microphone":
    new_audio = st_audiorec()
else:
    new_audio = st.file_uploader("Upload audio file", type=["wav", "mp3", "m4a"])
lecture_language = st.selectbox("Select lecture language:", ("English", "Spanish", "French"))
st.session_state.audio_language = language_map[lecture_language]
upload = st.button("Upload and process")

if new_audio is not None and upload:
    # Adopt the widget's audio only when it is actually processed. Writing it
    # on every rerun replaced the audio of a lecture picked in the sidebar
    # with the (usually empty) widget value, leaving the player blank.
    # NOTE(review): confirm other pages don't rely on audio_file mirroring
    # the widget on every run.
    st.session_state.audio_file = new_audio
    st.session_state.lecture_id = generate_uuid()
    st.session_state.created_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with st.status("Processing lecture audio..."):
        st.write("Transcribing audio...")
        transcribe_file()
        st.write("Cleaning up transcript...")
        clean_up_transcript()
        st.write("Generating metadata...")
        generate_metadata()
        st.write("Generating thumbnail...")
        generate_thumbnail()
        st.write("Uploading to database...")
        upload_to_database()
elif upload:
    st.error("Please upload an audio file or create an audio recording.")

st.divider()

# --- Sidebar: the signed-in user's lectures, filterable by a search box. ---
with st.sidebar:
    st.header("Your Lectures")
    search_query = st.text_input("Search lectures", placeholder="Search by title, keywords...")

    if not st.session_state.lectures or st.session_state.updateSidebar:
        # Bind the username as a query parameter instead of splicing it into
        # the SQL text, so a quote in a username cannot change the query.
        QUERY = (f"SELECT * FROM `{st.secrets.project}.{st.secrets.dataset}.Lectures` "
                 "WHERE Username = @username")
        lectures_job_config = bigquery.QueryJobConfig(
            query_parameters=[
                bigquery.ScalarQueryParameter("username", "STRING", st.session_state.username)
            ]
        )
        query_job = bigquery_client.query(QUERY, job_config=lectures_job_config)
        st.session_state.lectures = list(query_job.result())
        st.session_state.updateSidebar = False

    needle = search_query.lower()
    # Thumbnail_url is "gs://<bucket>/<object path>"; strip that prefix.
    thumbnail_prefix_len = len(f"gs://{st.secrets.bucket}/")
    for row in st.session_state.lectures:
        # A NULL Title/Summary would crash .lower(); treat it as empty text.
        searchable = ((row.Title or "").lower(), (row.Summary or "").lower())
        if not any(needle in text for text in searchable):
            continue
        left, right = st.columns([1, 3])
        with left:
            file_path = row.Thumbnail_url[thumbnail_prefix_len:]
            file_name = file_path.split("/")[-1]
            # Reuse a previously downloaded copy instead of fetching every
            # thumbnail from GCS on every rerun.
            if not os.path.exists(file_name):
                bucket.blob(file_path).download_to_filename(file_name)
            st.image(file_name, use_column_width=True)
        with right:
            st.write(row.Title)
            st.caption(row.Created_at)
            st.button("Select", key=f"view-{row.Lecture_id}", on_click=load_lecture_from_row, args=(row,))

# --- Details of the currently loaded or freshly processed lecture. ---
st.header("Lecture Details")

if st.session_state.transcript:
    left_col, right_col = st.columns(2)
    with left_col:
        # thumbnail_filename stays None when no thumbnail was ever produced.
        if st.session_state.thumbnail_filename:
            st.image(st.session_state.thumbnail_filename)
    with right_col:
        st.header(st.session_state.title)
        st.caption(st.session_state.created_at)
        st.write(st.session_state.summary)
        left_button, right_button = st.columns(2)
        with left_button:
            st.page_link("pages/chat.py", label="Open Chat Page")
        with right_button:
            st.page_link("pages/quiz.py", label="Open Quiz Page")

    st.subheader("Audio Recording")

    st.audio(st.session_state.audio_file)

    st.download_button(label="Download Transcript",
                       data=st.session_state.transcript,
                       file_name=f"{st.session_state.lecture_id}.txt",
                       mime="text/plain")

    with st.expander("View Transcript"):
        translate_language = st.selectbox("Translate transcript:", ("Don't translate", "English", "Spanish", "French"))
        transcript = st.session_state.transcript
        if translate_language != "Don't translate":
            # Cache translations per (language, transcript). Otherwise every
            # widget interaction (each rerun) repeats the LLM call.
            translations = st.session_state.get("translations", {})
            cache_key = (translate_language, transcript)
            if cache_key not in translations:
                # Name the language ("Spanish") rather than its locale code
                # ("es-ES") so the instruction is unambiguous to the model.
                translations[cache_key] = generate_text(
                    f"Translate the following transcript from its current language to {translate_language}:\n\n{transcript}")
            st.session_state.translations = translations
            transcript = translations[cache_key]
        st.write(transcript)
else:
    st.write("Please select a lecture or upload a new one.")