auto-fl-fit / pages / 6_๐ŸŒ๐Ÿค–_Federated_Learning.py
6_๐ŸŒ๐Ÿค–_Federated_Learning.py
Raw
import streamlit as st
from tools import *
import pandas as pd
from pycaret.regression import RegressionExperiment
from pycaret.classification import ClassificationExperiment
import time
from streamlit_extras.stateful_button import button
import subprocess


_, file_tree_ph = sidebar(globals())

st.title("Federated Learning Test")

# An experiment has to be selected (stored in session state) before this page
# can do anything useful.
exp_id = st.session_state.get('exp_id')
if exp_id is None:
    st.write("Please select an experiment first.")
    st.stop()

# Use the generated (synthetic) dataset when that option is set in session
# state, otherwise the training dataset.
use_synthetic_data = st.session_state.get('use_synth_data')
df = pd.read_csv(
    PATH_TO_GEN_DATASET if use_synthetic_data else PATH_TO_TRAIN_DATASET)

with st.container(border=True):
    TARGET = st.selectbox(
        'Choose the Target Column', df.columns, key="preserve_chosen_target", index=len(df.columns)-1)

    # Default to "regression" (index 1) when identify_problem_type reports it,
    # otherwise "classification" (index 0).
    PROBLEM_TYPE = st.selectbox(
        "Problem Type", ["classification", "regression"], key="preserve_chosen_problem",
        index=int(identify_problem_type(df) == 'regression'))

    # The selectbox only offers these two values, so the lookup always succeeds.
    experiment_cls = {
        "classification": ClassificationExperiment,
        "regression": RegressionExperiment,
    }[PROBLEM_TYPE]
    exp = experiment_cls()
    exp.setup(df, target=TARGET, verbose=False,
              ignore_features=[], experiment_name=exp_id)

    NUM_CLIENTS = st.number_input(
        "Number of Clients", min_value=3, max_value=10, value=3, key="preserve_num_clients")
    NUM_ROUNDS = st.number_input(
        "Number of Rounds", min_value=1, max_value=1000, value=10, key="preserve_num_rounds")

    models = exp.models()
    UNSUPPORTED_MODEL = ["catboost"]
    # Every model PyCaret offers for this problem type, minus the ones the FL
    # server cannot handle. Filtering the index (rather than taking a set
    # difference) keeps PyCaret's ordering, so the dropdown order and the
    # default choice below are stable across reruns and server restarts.
    SUPPORTED_MODEL = [
        model_id for model_id in models.index if model_id not in UNSUPPORTED_MODEL]

    if not SUPPORTED_MODEL:
        st.error("No models supported by federated learning are available for this problem type.")
        st.stop()

    # Reset a stale selection (e.g. left over from the other problem type)
    # before the keyed selectbox is created.
    if st.session_state.get('preserve_chosen_model') not in SUPPORTED_MODEL:
        st.session_state['preserve_chosen_model'] = SUPPORTED_MODEL[0]

    MODEL_NAME = st.selectbox(
        'Choose the Model', options=SUPPORTED_MODEL, format_func=lambda x: models.loc[x, 'Name'], key="preserve_chosen_model")


def run_command(args):
    """Run a command and stream its combined stdout/stderr into Streamlit.

    The command line is echoed with ``st.info``. Each output line is rendered
    as text inside a scrollable "Output" expander. When the process exits
    with a non-zero status, an ``st.error`` message is shown.

    Args:
        args: Program and its arguments as a list of strings. They are passed
            to ``subprocess.Popen`` directly (no shell).

    Returns:
        The process exit code (0 on success).
    """
    st.info(f"Running '{' '.join(args)}'")

    expander = st.expander("Output", expanded=True)
    container = expander.container(height=300, border=False)

    # Custom CSS so long output lines wrap inside the expander instead of
    # overflowing horizontally.
    custom_css = """
    <style>
        [data-testid="stExpander"] [data-testid=stText] {
            white-space: pre-wrap !important;
        }
    </style>
    """
    st.markdown(custom_css, unsafe_allow_html=True)

    process = subprocess.Popen(
        args=args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    )
    # Closing the pipe when done and then waiting reaps the child process and
    # makes its exit status available.
    with process.stdout:
        for raw_line in process.stdout:
            # Tolerate non-UTF-8 bytes from the child instead of crashing the page.
            container.text(raw_line.decode(errors="replace"))
    returncode = process.wait()

    if returncode != 0:
        st.error(f"Command exited with status {returncode}")
    return returncode


btn = st.button("Run Test")
if btn:
    # NOTE(review): the leading "." presumably turns "./x" paths into "../x"
    # because run.sh works one directory below this page's working dir —
    # confirm against the PATH_TO_* values and run.sh.
    dataset_path = PATH_TO_GEN_DATASET if use_synthetic_data else PATH_TO_TRAIN_DATASET
    start = time.perf_counter()  # monotonic clock for measuring durations
    returncode = run_command([f"{PATH_TO_FL}/run.sh", str(NUM_ROUNDS), str(NUM_CLIENTS),
                              MODEL_NAME, PROBLEM_TYPE, TARGET, exp_id, "." + dataset_path, "."+PATH_TO_MODELS])
    elapsed = time.perf_counter() - start
    # Only celebrate when run.sh actually succeeded; run_command has already
    # reported the failure otherwise.
    if returncode == 0:
        st.balloons()
        st.success("Training is done!")
    st.info(f"Time taken: {elapsed:.2f} seconds")
    update_file_tree(file_tree_ph, globals())

# Show previously produced FL results for the selected experiment/model, if any.
# `os` is not imported in this file's visible imports; presumably it comes from
# `from tools import *` — TODO confirm.
fl_results = None
pth = f"{PATH_TO_FL}/results/{exp_id}-results-{MODEL_NAME}.csv"
if os.path.exists(pth):
    fl_results = pd.read_csv(pth)

if fl_results is not None:
    st.subheader("Results")
    # Move the 'Model' column to the front when the results file has one.
    if 'Model' in fl_results.columns:
        fl_results.insert(0, 'Model', fl_results.pop('Model'))
    st.dataframe(fl_results, hide_index=True, use_container_width=True)