# auto-fl-fit / tools.py
import sys
import hashlib
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
import json
import os
import numpy as np
from directory_tree import display_tree

from pycaret.regression import RegressionExperiment
from pycaret.classification import ClassificationExperiment

# NOTE(review): not referenced in the visible part of this file; presumably the
# location of the pycaret-fl checkout — confirm against callers.
PATH_TO_FL = "./pycaret-fl"
# Root directories for datasets and models; init_paths() nests a per-experiment
# sub-directory under each when an experiment ID is supplied.
CONST_PATH_TO_DATA = "./data"
CONST_PATH_TO_MODELS = "./models"


def init_paths(scope, exp_id=None):
    """Create the data/model directories and publish their paths in *scope*.

    Args:
        scope: Mapping supporting item assignment. It receives the entries
            ``PATH_TO_DATA``, ``PATH_TO_MODELS`` and the four dataset CSV
            paths (``PATH_TO_DATASET``, ``PATH_TO_TRAIN_DATASET``,
            ``PATH_TO_EVAL_DATASET``, ``PATH_TO_GEN_DATASET``).
        exp_id: Optional experiment identifier. When truthy, both directories
            are nested under ``<root>/<exp_id>``; otherwise the root
            directories themselves are used.
    """
    if exp_id:
        path_to_data = os.path.join(CONST_PATH_TO_DATA, exp_id)
        path_to_models = os.path.join(CONST_PATH_TO_MODELS, exp_id)
    else:
        path_to_data = CONST_PATH_TO_DATA
        path_to_models = CONST_PATH_TO_MODELS

    # exist_ok=True replaces the former exists()/makedirs() pair, which could
    # raise FileExistsError if the directory appeared between the two calls.
    os.makedirs(path_to_data, exist_ok=True)
    os.makedirs(path_to_models, exist_ok=True)

    scope['PATH_TO_DATA'] = path_to_data
    scope['PATH_TO_MODELS'] = path_to_models

    # Kept as "/"-joined strings so the published values stay unchanged.
    scope['PATH_TO_DATASET'] = f"{path_to_data}/dataset.csv"
    scope['PATH_TO_TRAIN_DATASET'] = f"{path_to_data}/train_dataset.csv"
    scope['PATH_TO_EVAL_DATASET'] = f"{path_to_data}/eval_dataset.csv"
    scope['PATH_TO_GEN_DATASET'] = f"{path_to_data}/gen_dataset.csv"


def init_session():
    """Seed the Streamlit session state with its default entries.

    ``use_synth_data`` defaults to ``False``. ``exp_id`` is taken from the
    ``exp_id`` key of ``config.json`` in the working directory when that file
    exists, and is ``None`` otherwise. Entries already present in the session
    state are left untouched.
    """
    state = st.session_state

    if 'use_synth_data' not in state:
        state['use_synth_data'] = False

    if 'exp_id' in state:
        return

    if not os.path.exists('config.json'):
        state['exp_id'] = None
        return

    # Restore the experiment ID from config.json.
    with open('config.json', 'r') as config_file:
        stored = json.load(config_file)
    state['exp_id'] = stored.get('exp_id')


def update_exp_id(exp_id_ph: DeltaGenerator, exp_id: str):
    """Show the experiment status in the given placeholder.

    Args:
        exp_id_ph: Streamlit element that receives the output.
        exp_id: Current experiment ID; a falsy value means no experiment has
            been initialised yet.
    """
    if not exp_id:
        exp_id_ph.info("Experiment not initialized, please upload a dataset")
        return

    exp_id_ph.caption(f"Experiment ID: *{exp_id}*")


def update_file_tree(ph, scope):
    """Render the data/models directory tree into the placeholder *ph*.

    Nothing is rendered unless both ``PATH_TO_DATA`` and ``PATH_TO_MODELS``
    are present and truthy in *scope*.

    Args:
        ph: Streamlit element on which ``container(...)`` is called.
        scope: Mapping holding the ``PATH_TO_DATA`` / ``PATH_TO_MODELS`` entries.
    """
    data_dir = scope.get('PATH_TO_DATA')
    if not data_dir:
        return
    models_dir = scope.get('PATH_TO_MODELS')
    if not models_dir:
        return

    import seedir as sd

    # Synthetic root with fixed "data" and "models" folders, so the display
    # does not depend on the actual directory names.
    root = sd.FakeDir("Structure")
    data_node = root.create_folder('data')
    models_node = root.create_folder('models')

    # Graft the children of the fake-dir built from each path onto its folder.
    data_node._children = sd.fakedir(data_dir)._children
    models_node._children = sd.fakedir(models_dir)._children

    rendered = root.seedir(style='emoji', printout=False)
    ph.container(height=310, border=True).text(rendered)


def sidebar(scope):
    """Build the shared sidebar and return its two placeholders.

    Sets the wide page layout, initialises the session defaults through
    ``init_session``, injects the sidebar CSS overrides and then renders,
    inside ``st.sidebar`` and in this order: the synthetic-data checkbox, the
    file-tree placeholder and the experiment-ID placeholder.

    Args:
        scope: Mapping handed to ``init_paths`` (which stores the path
            entries in it) and then read by ``update_file_tree``.

    Returns:
        Tuple ``(exp_id_ph, file_tree_ph)`` with the ``st.empty()``
        placeholders created for the experiment status and the file tree.
    """
    # NOTE(review): Streamlit presumably requires set_page_config before other
    # st.* calls — confirm before reordering anything in this function.
    st.set_page_config(layout="wide")

    # Ensures 'use_synth_data' and 'exp_id' exist before they are read below.
    init_session()

    # Custom CSS for changing the sidebar navigation
    custom_css = """
    <style>
        [data-testid=stSidebarContent] {
            overflow: overlay !important;
        }
        [data-testid=stSidebarNavItems] {
            max-height: 75vh !important;
        }
        [data-testid=stSidebarUserContent] {
            padding-bottom: 1rem !important;
        }
    </style>
    """
    # Apply custom CSS
    st.markdown(custom_css, unsafe_allow_html=True)

    with st.sidebar:
        # Toggling the checkbox calls reset_run_btn, clearing 'run_mod_btn'.
        # NOTE(review): both `key` and `value` refer to the same session-state
        # entry; confirm Streamlit does not warn about the duplicated default.
        st.checkbox("Use synthetic data", key="use_synth_data",
                    value=st.session_state['use_synth_data'], on_change=reset_run_btn)

        # Create the directories and publish the path entries in scope.
        exp_id = st.session_state['exp_id']
        init_paths(scope, exp_id)

        # Placeholders are returned so callers can re-render them later.
        file_tree_ph = st.empty()
        update_file_tree(file_tree_ph, scope)

        exp_id_ph = st.empty()
        update_exp_id(exp_id_ph, exp_id)

        return exp_id_ph, file_tree_ph


def show_df(df):
    """Display *df* with ``st.dataframe``, using the full container width."""
    st.dataframe(df, use_container_width=True)


def identify_problem_type(df, threshold=10):
    """Guess whether the chosen target is a classification or regression task.

    The target column name is read from
    ``st.session_state['preserve_chosen_target']``; the decision is based on
    the number of distinct values in that column.

    Args:
        df: Dataset indexable by column name (presumably a pandas DataFrame —
            TODO confirm against callers).
        threshold: Targets with fewer distinct values than this are treated
            as classification targets.

    Returns:
        ``"classification"`` when the target has fewer than *threshold*
        distinct values, otherwise ``"regression"``.

    Raises:
        KeyError: If no target has been stored in the session state, or the
            stored name is not a column of *df*.
    """
    import pandas as pd

    chosen_target = st.session_state.get('preserve_chosen_target')
    if chosen_target is None:
        # Previously surfaced as an opaque "KeyError: None" from the lookup.
        raise KeyError(
            "no target column selected ('preserve_chosen_target' is not set)")

    # pd.unique does not sort, so object columns mixing strings and NaN (common
    # in uploaded CSVs) no longer fail with the TypeError np.unique raises.
    distinct_count = len(pd.unique(df[chosen_target]))
    if distinct_count < threshold:
        return "classification"
    return "regression"


def get_plot_types(problem_type):
    """Return the available plots for a problem type.

    Args:
        problem_type: ``'classification'`` or ``'regression'``.

    Returns:
        A new dict mapping each human-readable plot label to its plot
        identifier, in display order, or ``None`` when *problem_type* is
        neither of the two supported values.
    """
    if problem_type == 'classification':
        return {
            # 'Schematic drawing of the preprocessing pipeline': 'pipeline',
            'Area Under the Curve': 'auc',
            'Discrimination Threshold': 'threshold',
            'Precision Recall Curve': 'pr',
            'Confusion Matrix': 'confusion_matrix',
            'Class Prediction Error': 'error',
            'Classification Report': 'class_report',
            'Decision Boundary': 'boundary',
            'Recursive Feature Selection': 'rfe',
            'Learning Curve': 'learning',
            'Manifold Learning': 'manifold',
            'Calibration Curve': 'calibration',
            'Validation Curve': 'vc',
            'Dimension Learning': 'dimension',
            'Feature Importance': 'feature',
            'Feature Importance (All)': 'feature_all',
            'Model Hyperparameter': 'parameter',
            'Lift Curve': 'lift',
            'Gain Chart': 'gain',
            'Decision Tree': 'tree',
            'KS Statistic Plot': 'ks',
        }

    if problem_type == 'regression':
        return {
            # 'Schematic drawing of the preprocessing pipeline': 'pipeline',
            'Interactive Residual plots': 'residuals_interactive',
            'Residuals Plot': 'residuals',
            'Prediction Error Plot': 'error',
            'Cooks Distance Plot': 'cooks',
            'Recursive Feat. Selection': 'rfe',
            'Learning Curve': 'learning',
            'Validation Curve': 'vc',
            'Manifold Learning': 'manifold',
            'Feature Importance': 'feature',
            'Feature Importance (All)': 'feature_all',
            'Model Hyperparameter': 'parameter',
            'Decision Tree': 'tree',
        }

    # Unknown problem types keep yielding None, now stated explicitly.
    return None


def reset_run_btn():
    """Set the ``run_mod_btn`` entry of the Streamlit session state to False."""
    session = st.session_state
    session['run_mod_btn'] = False


@st.cache_resource()
def create_experiment(chosen_problem, df, target, ignore_features, exp_id):
    """Create and set up a PyCaret experiment for the chosen problem type.

    The function is wrapped in ``st.cache_resource``.

    Args:
        chosen_problem: ``'classification'`` or ``'regression'``.
        df: Dataset passed to ``exp.setup``.
        target: Target passed to ``exp.setup``.
        ignore_features: Passed to ``exp.setup`` as ``ignore_features``.
        exp_id: Passed to ``exp.setup`` as ``experiment_name``.

    Returns:
        The ``ClassificationExperiment`` or ``RegressionExperiment`` instance
        after ``setup`` has been called on it.

    Raises:
        ValueError: If *chosen_problem* is not one of the supported values.
    """
    if chosen_problem == 'classification':
        exp = ClassificationExperiment()
    elif chosen_problem == 'regression':
        exp = RegressionExperiment()
    else:
        # Previously fell through and failed with UnboundLocalError on `exp`.
        raise ValueError(
            f"Unsupported problem type: {chosen_problem!r} "
            "(expected 'classification' or 'regression')")

    exp.setup(df, target=target, verbose=False,
              ignore_features=ignore_features, experiment_name=exp_id)
    return exp