# Streamlit page: "Model Building".
#
# Flow of the script:
#   1. Require an experiment id in the session state, otherwise stop.
#   2. Load either the generated or the training dataset.
#   3. Let the user choose target column, ignored columns, problem type and
#      whether to cross-validate.
#   4. When the "Run Modelling" button is active, set up a PyCaret experiment,
#      compare models, show the result tables and save the best model.
#
# NOTE: import order is kept as-is on purpose; `from tools import *` is a star
# import, so later imports may intentionally override names it brings in.
import streamlit as st
from tools import *
import pandas as pd
from pycaret.regression import RegressionExperiment
from pycaret.classification import ClassificationExperiment
import time
from streamlit_extras.stateful_button import button

_, file_tree_ph = sidebar(globals())

st.title("Model Building")

# The page cannot do anything useful without a selected experiment, so the
# script run is halted right here in that case.
exp_id = st.session_state.get('exp_id')
if exp_id is None:
    st.write("Please select an experiment first.")
    st.stop()

# Pick the generated dataset when the 'use_synth_data' flag is truthy,
# otherwise fall back to the training dataset.
use_synthetic_data = st.session_state.get('use_synth_data')
df = pd.read_csv(
    PATH_TO_GEN_DATASET if use_synthetic_data else PATH_TO_TRAIN_DATASET)

# Modelling options. Each widget keeps its value in the session state under
# its `key`; changing any of them triggers `reset_run_btn` (presumably to
# clear the stateful run button -- confirm in `tools`).
with st.container(border=True):
    st.selectbox(
        'Choose the Target Column',
        df.columns,
        key="preserve_chosen_target",
        index=len(df.columns) - 1,  # preselect the last column
        on_change=reset_run_btn,
    )
    st.multiselect(
        'Choose the Ignore Columns',
        df.columns,
        key="preserve_chosen_ignore",
        on_change=reset_run_btn,
    )
    st.selectbox(
        'Choose the Problem Type',
        ['classification', 'regression'],
        key="preserve_chosen_problem",
        # `index` is an integer position: 0 -> classification,
        # 1 -> regression. Convert the comparison result explicitly instead
        # of handing a bool to the widget.
        index=int(identify_problem_type(df) == 'regression'),
        on_change=reset_run_btn,
    )
    st.checkbox('Cross Validation', key="preserve_cv", on_change=reset_run_btn)

run_mod_btn = button('Run Modelling', key="run_mod_btn")
if run_mod_btn:
    if st.session_state.get('preserve_chosen_problem') == 'classification':
        exp = ClassificationExperiment()
    else:
        exp = RegressionExperiment()

    exp.setup(
        df,
        target=st.session_state.get("preserve_chosen_target"),
        verbose=False,
        ignore_features=st.session_state.get("preserve_chosen_ignore"),
        experiment_name=exp_id,
    )
    # pull() is called right after setup(), so this is expected to be the
    # setup summary table.
    setup_df = exp.pull()
    show_df(setup_df)

    # Measure how long the model comparison takes. perf_counter() is
    # monotonic, so the result is not distorted by system clock adjustments
    # the way time.time() differences can be.
    start = time.perf_counter()
    best_model = exp.compare_models(
        cross_validation=st.session_state.get('preserve_cv'))
    end = time.perf_counter()
    st.info(f"Time taken: {end-start} seconds")
    compare_df = exp.pull()
    show_df(compare_df)

    if compare_df.empty:
        # No row in the comparison table means there is no best model to
        # name or save; report it instead of failing on `index[0]`.
        st.error("No model could be trained, so nothing was saved.")
    else:
        # The first index label of the comparison table names the saved file
        # (presumably the id of the top-ranked model -- confirm).
        exp.save_model(
            best_model, f'{PATH_TO_MODELS}/best_{compare_df.index[0]}')

    # plot_types = get_plot_types(
    #     st.session_state.get('preserve_chosen_problem'))
    # tabs = st.tabs(plot_types.keys())
    # for k, tab in zip(plot_types.keys(), tabs):
    #     with tab:
    #         try:
    #             exp.plot_model(best_model, display_format='streamlit',
    #                            plot=plot_types[k])
    #         except Exception as e:
    #             st.error(e)

    # Refresh the file tree placeholder created by sidebar() above
    # (presumably so a newly saved model shows up -- confirm in `tools`).
    update_file_tree(file_tree_ph, globals())

    # models = compare_df.to_dict()['Model']
    # st.session_state['preserve_models'] = models