import streamlit as st
from tools import *
import pandas as pd
from pycaret.regression import RegressionExperiment
from pycaret.classification import ClassificationExperiment
import time
from streamlit_extras.stateful_button import button
import subprocess
_, file_tree_ph = sidebar(globals())
st.title("Federated Learning Test")

# Nothing on this page can run without an active experiment, so end the
# script run (st.stop) until the user has selected one.
exp_id = st.session_state.get('exp_id')
if exp_id is None:
    st.write("Please select an experiment first.")
    st.stop()

# The session flag decides whether the generated (synthetic) dataset or the
# original training dataset is loaded.
use_synthetic_data = st.session_state.get('use_synth_data')
df = pd.read_csv(
    PATH_TO_GEN_DATASET if use_synthetic_data else PATH_TO_TRAIN_DATASET
)
with st.container(border=True):
    # Default to the last column as the prediction target.
    TARGET = st.selectbox(
        'Choose the Target Column', df.columns,
        key="preserve_chosen_target", index=len(df.columns) - 1)
    # Pre-select "regression" (index 1) when identify_problem_type reports it,
    # otherwise "classification" (index 0).
    PROBLEM_TYPE = st.selectbox(
        "Problem Type", ["classification", "regression"],
        key="preserve_chosen_problem",
        index=int(identify_problem_type(df) == 'regression'))

    # The selectbox only offers these two options, so `exp` is always bound.
    exp = (ClassificationExperiment() if PROBLEM_TYPE == 'classification'
           else RegressionExperiment())
    exp.setup(df, target=TARGET, verbose=False,
              ignore_features=[], experiment_name=exp_id)

    NUM_CLIENTS = st.number_input(
        "Number of Clients", min_value=3, max_value=10, value=3,
        key="preserve_num_clients")
    NUM_ROUNDS = st.number_input(
        "Number of Rounds", min_value=1, max_value=1000, value=10,
        key="preserve_num_rounds")

    models = exp.models()
    UNSUPPORTED_MODEL = ["catboost"]
    # Offer every PyCaret model except those the federated server cannot
    # handle. Iterate models.index (not a set difference) so PyCaret's order,
    # and therefore the default choice, is stable across server restarts.
    SUPPORTED_MODEL = [m for m in models.index if m not in UNSUPPORTED_MODEL]
    if not SUPPORTED_MODEL:
        st.error("No supported model is available for this problem type.")
        st.stop()
    # A stale choice (e.g. a model id from the other problem type) is not a
    # valid option, so reset it before the widget is created.
    if st.session_state.get('preserve_chosen_model') not in SUPPORTED_MODEL:
        st.session_state['preserve_chosen_model'] = SUPPORTED_MODEL[0]
    MODEL_NAME = st.selectbox(
        'Choose the Model', options=SUPPORTED_MODEL,
        format_func=lambda x: models.loc[x, 'Name'],
        key="preserve_chosen_model")
def run_command(args):
    """Run a command and stream its combined stdout/stderr into the page.

    The command line is shown first. Each output line is then appended to a
    scrollable "Output" expander as the process produces it. Once the process
    has exited, a non-zero exit code is reported with ``st.error``.

    Args:
        args: The command and its arguments as a list of strings. It is
            executed directly, without a shell.

    Returns:
        The process exit code (0 on success).
    """
    st.info(f"Running '{' '.join(args)}'")
    expander = st.expander("Output", expanded=True)
    container = expander.container(height=300, border=False)
    # Custom CSS so long output lines wrap inside the expander.
    custom_css = """
<style>
[data-testid="stExpander"] [data-testid=stText] {
white-space: pre-wrap !important;
}
</style>
"""
    st.markdown(custom_css, unsafe_allow_html=True)
    # stderr is merged into stdout so messages keep their original order.
    # Leaving the `with` block closes the pipe and waits for the process.
    with subprocess.Popen(
        args=args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    ) as process:
        for raw_line in process.stdout:
            # errors="replace" stops a stray non-UTF-8 byte from aborting the
            # whole stream.
            container.text(raw_line.decode(errors="replace"))
    if process.returncode != 0:
        st.error(f"Command failed with exit code {process.returncode}.")
    return process.returncode


btn = st.button("Run Test")
if btn:
    # perf_counter is monotonic, so wall-clock adjustments cannot skew timing.
    start = time.perf_counter()
    # NOTE(review): the leading "." on the data/model paths presumably rebases
    # them for run.sh's working directory — confirm against run.sh.
    returncode = run_command([
        f"{PATH_TO_FL}/run.sh", str(NUM_ROUNDS), str(NUM_CLIENTS),
        MODEL_NAME, PROBLEM_TYPE, TARGET, exp_id,
        "." + (PATH_TO_GEN_DATASET if use_synthetic_data else PATH_TO_TRAIN_DATASET),
        "." + PATH_TO_MODELS,
    ])
    end = time.perf_counter()
    # Only celebrate a successful run; run_command has already shown the
    # error for a failed one.
    if returncode == 0:
        st.balloons()
        st.success("Training is done!")
    st.info(f"Time taken: {end - start:.2f} seconds")
    update_file_tree(file_tree_ph, globals())

# Show the results file for the selected experiment/model if one exists.
fl_results = None
pth = f"{PATH_TO_FL}/results/{exp_id}-results-{MODEL_NAME}.csv"
if os.path.exists(pth):
    fl_results = pd.read_csv(pth)
if fl_results is not None:
    st.subheader("Results")
    # Move the model column to the front when the file has one.
    if 'Model' in fl_results.columns:
        fl_results.insert(0, 'Model', fl_results.pop('Model'))
    st.dataframe(fl_results, hide_index=True, use_container_width=True)