# MiniTorch / project / run_mnist_interface.py
import time

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from run_mnist_multiclass import ImageTrain, make_mnist


def render_run_image_interface():
    """Render the Streamlit page for training the MNIST image model.

    The page shows, top to bottom: dataset controls with a preview of one
    training image, hyperparameter widgets, "Train Model" / "Stop Model"
    buttons, and a set of placeholders that are refreshed by a logging
    callback while ``ImageTrain.train`` runs.

    Returns:
        None. All output goes to the Streamlit page (and stdout via ``print``).
    """

    st.markdown("### Dataset")
    n_training_samples = st.number_input(
        "Number of training samples",
        min_value=100,
        max_value=10000,
        step=10,
        value=1000,
    )
    (X_train, y_train) = make_mnist(0, n_training_samples)

    # Cap the preview index at the last loaded sample. A fixed upper bound of
    # 100 is out of range when only 100 samples (indices 0-99) are loaded.
    last_index = max(0, min(100, len(X_train) - 1))
    show = st.number_input(
        "Image",
        min_value=0,
        max_value=last_index,
        step=1,
        value=min(1, last_index),
    )
    st.write(
        px.imshow(X_train[show], title="y =" + str([int(i) for i in y_train[show]]))
    )

    st.markdown("### Hyperparameters")
    col1, col2 = st.columns(2)
    learning_rate = col1.selectbox(
        "Learning rate", [0.001, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0], index=2
    )
    max_epochs = col2.number_input("Number of epochs", min_value=1, step=25, value=500)

    col1, col2 = st.columns(2)
    st_train_button = col1.empty()
    # The return value is deliberately unused.
    # NOTE(review): presumably a click reruns the script, which interrupts an
    # in-progress training loop -- confirm against Streamlit's execution model.
    col2.button("Stop Model")

    # Placeholders that the logging callback overwrites on every call.
    st_progress = st.empty()
    st_epoch_timer = st.empty()
    st_epoch_image = st.empty()
    st_epoch_plot = st.empty()
    st_epoch_stats = st.empty()

    # One dict per log_fn call, oldest first.
    history = []

    if st_train_button.button("Train Model"):
        start_time = time.time()
        train = ImageTrain()

        def log_fn(epoch, total_loss, correct, losses, model):
            """Refresh the live widgets with the latest training statistics.

            Args:
                epoch: Current epoch number as reported by the trainer.
                total_loss: Loss value reported for this call.
                correct: Count of correct predictions reported for this call.
                losses: Sequence of loss values recorded so far.
                model: Model being trained; its ``mid`` attribute is plotted.
            """
            time_elapsed = time.time() - start_time
            # st.progress only accepts floats within [0.0, 1.0]; clamp so an
            # unexpected epoch value cannot abort the training run.
            st_progress.progress(min(max(epoch / max_epochs, 0.0), 1.0))
            # NOTE(review): the "+ 1" assumes a 0-based epoch counter -- confirm
            # against ImageTrain.train.
            time_per_epoch = time_elapsed / (epoch + 1)
            st_epoch_timer.markdown(
                "Epoch {}/{}. Time per epoch: {:,.3f}s. Time left: {:,.2f}s.".format(
                    epoch,
                    max_epochs,
                    time_per_epoch,
                    (max_epochs - epoch) * time_per_epoch,
                )
            )
            history.append({"epoch": epoch, "loss": total_loss, "correct": correct})
            # Newest entry first so the latest statistics stay at the top.
            st_epoch_stats.write(pd.DataFrame(history[::-1]))

            # Visualize the first entry of model.mid (values negated for display).
            # NOTE(review): presumably intermediate activations for the most
            # recent batch -- confirm against the model definition.
            fig = px.imshow(-1 * model.mid.to_numpy()[0], facet_col=0)
            st_epoch_image.plotly_chart(fig)

            loss_graph = go.Scatter(mode="lines", x=list(range(len(losses))), y=losses)
            fig = go.Figure(loss_graph)
            fig.update_layout(
                title="Loss Graph",
                xaxis=dict(range=[0, max_epochs]),
            )
            # max() raises ValueError on an empty sequence, so only pin the
            # y-axis once at least one loss has been recorded.
            if len(losses) > 0:
                fig.update_layout(yaxis=dict(range=[0, max(losses)]))
            st_epoch_plot.plotly_chart(fig)
            print(
                f"Epoch: {epoch}/{max_epochs}, loss: {total_loss}, correct: {correct}"
            )

        train.train(
            (X_train, y_train),
            make_mnist(10000, 10500),
            learning_rate,
            max_epochs,
            log_fn,
        )