"""Streamlit page that launches a federated-learning test run.

The page lets the user pick a target column, problem type, number of
clients/rounds and a PyCaret model, then executes ``run.sh`` under
``PATH_TO_FL`` and displays the resulting CSV (if one exists).
"""
import streamlit as st
from tools import *
import pandas as pd
from pycaret.regression import RegressionExperiment
from pycaret.classification import ClassificationExperiment
import time
from streamlit_extras.stateful_button import button
import subprocess
# Imported explicitly: os.path is used below and must not depend on the
# star import from `tools` happening to re-export it.
import os

_, file_tree_ph = sidebar(globals())

st.title("Federated Learning Test")

# The page cannot do anything without a selected experiment.
exp_id = st.session_state.get('exp_id')
if exp_id is None:
    st.write("Please select an experiment first.")
    st.stop()

use_synthetic_data = st.session_state.get('use_synth_data')
df = pd.read_csv(
    PATH_TO_GEN_DATASET if use_synthetic_data else PATH_TO_TRAIN_DATASET)

# NOTE(review): the extent of this container was reconstructed from a
# whitespace-flattened source — confirm which widgets belong in the box.
with st.container(border=True):
    # Default to the last column as the prediction target.
    TARGET = st.selectbox(
        'Choose the Target Column',
        df.columns,
        key="preserve_chosen_target",
        index=len(df.columns) - 1)

    # Options are ordered [classification, regression], so the default index
    # is 1 when the helper detects a regression problem and 0 otherwise.
    PROBLEM_TYPE = st.selectbox(
        "Problem Type",
        ["classification", "regression"],
        key="preserve_chosen_problem",
        index=int(identify_problem_type(df) == 'regression'))

    if PROBLEM_TYPE == 'classification':
        exp = ClassificationExperiment()
    else:
        exp = RegressionExperiment()

    exp.setup(df, target=TARGET, verbose=False,
              ignore_features=[], experiment_name=exp_id)

    NUM_CLIENTS = st.number_input(
        "Number of Clients",
        min_value=3, max_value=10, value=3,
        key="preserve_num_clients")

    NUM_ROUNDS = st.number_input(
        "Number of Rounds",
        min_value=1, max_value=1000, value=10,
        key="preserve_num_rounds")

    models = exp.models()

    UNSUPPORTED_MODEL = ["catboost"]
    # Every model PyCaret offers except the unsupported ones. Filtering in
    # index order (rather than via a set) keeps the option order, and hence
    # the fallback default below, stable between runs.
    SUPPORTED_MODEL = [
        model_id for model_id in models.index
        if model_id not in UNSUPPORTED_MODEL
    ]

    # Reset a remembered choice that is not valid for the current problem
    # type before the widget reads it from session state.
    if st.session_state.get('preserve_chosen_model') not in SUPPORTED_MODEL:
        st.session_state['preserve_chosen_model'] = SUPPORTED_MODEL[0]

    MODEL_NAME = st.selectbox(
        'Choose the Model',
        options=SUPPORTED_MODEL,
        format_func=lambda x: models.loc[x, 'Name'],
        key="preserve_chosen_model")


def run_command(args):
    """Run a command and stream its combined stdout/stderr into Streamlit.

    Args:
        args: Sequence of strings forming the command line; passed to
            ``subprocess.Popen`` without a shell.

    Returns:
        The exit code of the finished process (0 conventionally means
        success), so the caller can tell whether the command failed.
    """
    st.info(f"Running '{' '.join(args)}'")
    expander = st.expander("Output", expanded=True)
    container = expander.container(height=300, border=False)

    # Custom CSS for changing the text style (currently whitespace only).
    custom_css = """ """
    # Apply custom CSS
    st.markdown(custom_css, unsafe_allow_html=True)

    # The context manager closes the pipe and waits for the child on exit,
    # so `returncode` is populated once the block is left.
    with subprocess.Popen(
        args=args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    ) as process:
        for raw_line in process.stdout:
            # Replace undecodable bytes instead of aborting the stream.
            container.text(raw_line.decode(errors="replace"))

    return process.returncode


btn = st.button("Run Test")

if btn:
    # measure time
    start = time.time()
    return_code = run_command([
        f"{PATH_TO_FL}/run.sh",
        str(NUM_ROUNDS),
        str(NUM_CLIENTS),
        MODEL_NAME,
        PROBLEM_TYPE,
        TARGET,
        exp_id,
        "." + (PATH_TO_GEN_DATASET if use_synthetic_data
               else PATH_TO_TRAIN_DATASET),
        "." + PATH_TO_MODELS,
    ])
    end = time.time()

    # Only celebrate when the script actually succeeded.
    if return_code == 0:
        st.balloons()
        st.success("Training is done!")
    else:
        st.error(f"Training failed (exit code {return_code}).")
    st.info(f"Time taken: {end-start} seconds")
    update_file_tree(file_tree_ph, globals())

# Show the results file for this experiment/model pair if one exists.
fl_results = None
pth = f"{PATH_TO_FL}/results/{exp_id}-results-{MODEL_NAME}.csv"
if os.path.exists(pth):
    fl_results = pd.read_csv(pth)

if fl_results is not None:
    st.subheader("Results")
    # Move the 'Model' column to the front of the table.
    fl_results.insert(0, 'Model', fl_results.pop('Model'))
    st.dataframe(fl_results, hide_index=True, use_container_width=True)