# auto-fl-fit / pages / 4_🧠🔧_Centralized_Learning.py
import streamlit as st
from tools import *
import pandas as pd
from pycaret.regression import RegressionExperiment
from pycaret.classification import ClassificationExperiment
import time
from streamlit_extras.stateful_button import button


# Render the shared sidebar and keep the file-tree placeholder it returns;
# the placeholder is passed to update_file_tree() once modelling has run.
_, file_tree_ph = sidebar(globals())

st.title("Model Building")

# The page needs an experiment id in the session state; stop rendering
# the rest of the page until one has been selected.
exp_id = st.session_state.get('exp_id')
if exp_id is None:
    st.write("Please select an experiment first.")
    st.stop()

# Read from PATH_TO_GEN_DATASET when the 'use_synth_data' session flag is
# truthy, otherwise from PATH_TO_TRAIN_DATASET.
use_synthetic_data = st.session_state.get('use_synth_data')
dataset_path = (
    PATH_TO_GEN_DATASET if use_synthetic_data else PATH_TO_TRAIN_DATASET)
try:
    df = pd.read_csv(dataset_path)
except (FileNotFoundError, pd.errors.EmptyDataError) as err:
    # Show a readable message instead of a raw traceback when the dataset
    # has not been created yet or contains no data.
    st.error(f"Could not load the dataset '{dataset_path}': {err}")
    st.stop()

# Modelling options. Each widget stores its value in the session state under
# its key; reset_run_btn is invoked whenever an option changes (presumably to
# clear the stateful run button -- confirm in tools).
with st.container(border=True):
    # Default the target to the last column of the dataset.
    st.selectbox(
        'Choose the Target Column', df.columns, key="preserve_chosen_target",
        index=len(df.columns) - 1, on_change=reset_run_btn)
    st.multiselect(
        'Choose the Ignore Columns', df.columns,
        key="preserve_chosen_ignore", on_change=reset_run_btn)
    # Pre-select 'regression' (index 1) only when identify_problem_type()
    # reports it; any other result falls back to 'classification' (index 0).
    st.selectbox(
        'Choose the Problem Type', ['classification', 'regression'],
        key="preserve_chosen_problem",
        index=int(identify_problem_type(df) == 'regression'),
        on_change=reset_run_btn)
    st.checkbox('Cross Validation', key="preserve_cv", on_change=reset_run_btn)

# Stateful button from streamlit_extras, keyed as "run_mod_btn".
run_mod_btn = button('Run Modelling', key="run_mod_btn")

if run_mod_btn:
    # Pick the PyCaret experiment class matching the selected problem type;
    # anything other than 'classification' is treated as regression.
    if st.session_state.get('preserve_chosen_problem') == 'classification':
        exp = ClassificationExperiment()
    else:
        exp = RegressionExperiment()

    target = st.session_state.get("preserve_chosen_target")
    ignore_features = st.session_state.get("preserve_chosen_ignore") or []
    if target in ignore_features:
        # The target column cannot be ignored and predicted at the same
        # time; keep it and tell the user about the conflicting selection.
        st.warning(
            f"The target column '{target}' was also selected as an ignore "
            "column; it will not be ignored.")
        ignore_features = [col for col in ignore_features if col != target]

    exp.setup(df, target=target, verbose=False,
              ignore_features=ignore_features, experiment_name=exp_id)
    # pull() returns the table produced by the most recent experiment call,
    # here the setup summary.
    setup_df = exp.pull()
    show_df(setup_df)

    # Measure how long the model comparison takes. perf_counter() is
    # monotonic, so the result is not affected by system clock adjustments.
    start = time.perf_counter()
    best_model = exp.compare_models(
        cross_validation=st.session_state.get('preserve_cv'))
    end = time.perf_counter()
    st.info(f"Time taken: {end-start} seconds")
    compare_df = exp.pull()
    show_df(compare_df)

    no_model_trained = compare_df.empty or (
        isinstance(best_model, list) and not best_model)
    if no_model_trained:
        # With an empty comparison table there is no best model to save and
        # compare_df.index[0] would raise IndexError.
        st.error("No model could be trained with the current settings.")
    else:
        # The first index label of the comparison table is used in the file
        # name of the saved model.
        exp.save_model(
            best_model, f'{PATH_TO_MODELS}/best_{compare_df.index[0]}')

        # plot_types = get_plot_types(
        #     st.session_state.get('preserve_chosen_problem'))
        # tabs = st.tabs(plot_types.keys())
        # for k, tab in zip(plot_types.keys(), tabs):
        #     with tab:
        #         try:
        #             exp.plot_model(best_model, display_format='streamlit',
        #                            plot=plot_types[k])
        #         except Exception as e:
        #             st.error(e)

        # Refresh the sidebar file tree after the model has been saved.
        update_file_tree(file_tree_ph, globals())

        # models = compare_df.to_dict()['Model']
        # st.session_state['preserve_models'] = models