Class-Corder / auth.py
auth.py
Raw
import streamlit as st
import bcrypt
from google.cloud import bigquery

# BigQuery client bound to the project named in the Streamlit secrets.
client = bigquery.Client(project=st.secrets.project)

# Seed the per-session login state on the first run; later reruns keep
# whatever value is already stored.
if "status" not in st.session_state:
    st.session_state["status"] = "unverified"
if "username" not in st.session_state:
    st.session_state["username"] = ""

# Page header.
st.title("ClassCorder")
st.subheader("Record, Transcribe, and Understand Your Lectures")


def get_hashed_password(plain_text_password):
    """Return a bcrypt hash of ``plain_text_password`` as a text string.

    A new salt is generated on every call; bcrypt stores the salt inside the
    resulting hash, so only the returned value needs to be kept.
    """
    password_bytes = plain_text_password.encode('utf-8')
    salt = bcrypt.gensalt()
    hashed_bytes = bcrypt.hashpw(password_bytes, salt)
    return hashed_bytes.decode('utf-8')


def check_password():
    """Verify the submitted credentials against the Users table.

    Reads ``username_field`` and ``password_field`` from
    ``st.session_state``, fetches the stored password hashes for that
    username and compares the submitted password against each one with
    bcrypt.

    On a match ``st.session_state.status`` is set to ``"verified"`` and
    ``st.session_state.username`` to the submitted username. Otherwise the
    status is set to ``"incorrect"``.
    """
    submitted_username = st.session_state.username_field
    submitted_password = st.session_state.password_field.encode('utf-8')

    # The username comes straight from a text input, so it is bound as the
    # named parameter @username rather than spliced into the SQL text
    # (string concatenation here allowed SQL injection).
    QUERY = (f"SELECT password FROM `{st.secrets.project}.{st.secrets.dataset}.Users` "
             "WHERE username = @username")
    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter(
                "username", "STRING", submitted_username),
        ]
    )
    query_job = client.query(QUERY, job_config=job_config)
    rows = query_job.result()  # Waits for the query to finish

    for row in rows:
        if bcrypt.checkpw(submitted_password, row.password.encode('utf-8')):
            st.session_state.status = "verified"
            st.session_state.username = submitted_username
            return

    # No row matched (unknown user or wrong password).
    st.session_state.status = "incorrect"


def login_prompt():
    """Render the login form and the warning matching the current status.

    Draws the username and password inputs, a "Create user" button wired to
    ``create_user`` and a "Log in" button wired to ``check_password``, then
    shows a warning if ``st.session_state.status`` is one of the known
    failure states.
    """
    st.text_input("Enter username:", key="username_field")
    st.text_input("Enter password:", type="password", key="password_field")

    # Two equal-width columns: account creation on the left, login on the right.
    create_col, login_col = st.columns(2)
    create_col.button("Create user", on_click=create_user, use_container_width=True)
    login_col.button("Log in", on_click=check_password, use_container_width=True)

    # Each failure status maps to exactly one warning message.
    warnings_by_status = {
        "username_too_short": "Username must have at least 2 characters",
        "password_not_strong_enough": "Password must have at least 8 characters with at least one digit and one letter",
        "incorrect": "Incorrect username and password. Please try again",
        "user_already_exists": "Username already exists",
    }
    message = warnings_by_status.get(st.session_state.status)
    if message is not None:
        st.warning(message)


def logout():
    """Reset the session's login status to ``"unverified"``."""
    st.session_state["status"] = "unverified"


def create_user():
    """Register a new user from the login form fields.

    Reads ``username_field`` and ``password_field`` from
    ``st.session_state`` and records the outcome in
    ``st.session_state.status``:

    * ``"username_too_short"`` - username has fewer than 2 characters.
    * ``"password_not_strong_enough"`` - password has fewer than 8
      characters, or lacks a digit, or lacks a letter.
    * ``"user_already_exists"`` - a row with that username is already in
      the Users table.
    * ``"unverified"`` - the user was inserted (with a bcrypt-hashed
      password) and can now log in.
    """
    new_username = st.session_state.username_field
    new_password = st.session_state.password_field

    if len(new_username) < 2:
        st.session_state.status = "username_too_short"
        return

    # Password policy: 8 characters or more with at least one digit and
    # one letter.
    if (len(new_password) < 8
            or not any(char.isdigit() for char in new_password)
            or not any(char.isalpha() for char in new_password)):
        st.session_state.status = "password_not_strong_enough"
        return

    users_table = f"`{st.secrets.project}.{st.secrets.dataset}.Users`"

    # Both statements bind user input through named query parameters; the
    # previous string-built SQL allowed injection via the username field.
    lookup_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter("username", "STRING", new_username),
        ]
    )
    lookup_job = client.query(
        f"SELECT password FROM {users_table} WHERE username = @username",
        job_config=lookup_config)
    existing_rows = lookup_job.result()
    if existing_rows.total_rows != 0:
        st.session_state.status = "user_already_exists"
        return

    insert_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ScalarQueryParameter("username", "STRING", new_username),
            bigquery.ScalarQueryParameter(
                "password", "STRING", get_hashed_password(new_password)),
        ]
    )
    insert_job = client.query(
        f"INSERT INTO {users_table} (username, password) VALUES (@username, @password)",
        job_config=insert_config)  # API request
    insert_job.result()  # Waits for query to finish
    st.session_state.status = "unverified"


def welcome():
    """Show a success message, then switch to the dashboard page."""
    dashboard_page = "pages/dashboard.py"
    st.success("Login successful.")
    st.switch_page(dashboard_page)

# Gate the page: anyone not yet verified gets the login form and the script
# run is stopped there; otherwise execution continues to the welcome step.
needs_login = st.session_state.status != "verified"
if needs_login:
    login_prompt()
    st.stop()
welcome()