# auto-fl-fit / pages / 2_🔄📊_Generate_Dataset.py
import streamlit as st
import pandas as pd
from sdv.metadata import SingleTableMetadata
from sdv.lite import SingleTablePreset
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality
from sdv.evaluation.single_table import get_column_plot
from tools import *
from sdv.single_table import CTGANSynthesizer, GaussianCopulaSynthesizer
from sdv.sequential import PARSynthesizer
from imblearn.over_sampling import SMOTE
import smogn

# Page set-up: render the sidebar, require a selected experiment, load and show
# the real dataset with its detected SDV metadata, then ask for a row count.
# Statement order matters here because Streamlit renders widgets in the order
# the script executes them.

# `sidebar` is not defined in this file (presumably it comes from
# `from tools import *`). It is given this module's globals; only the second
# item of its return value is kept, and it is later passed to
# `update_file_tree` when the synthetic data is saved.
_, file_tree_ph = sidebar(globals())

st.title("Generate Synthetic Data")

# Without an experiment id in the session state there is nothing to work on:
# show a hint and end this script run.
exp_id = st.session_state.get('exp_id')
if exp_id is None:
    st.write("Please select an experiment first.")
    st.stop()

# Read the training dataset into a DataFrame
# NOTE(review): PATH_TO_TRAIN_DATASET is not defined in this file; presumably
# it is provided by `from tools import *` — confirm.
real_data = pd.read_csv(PATH_TO_TRAIN_DATASET)

# Disabled experiment: add a patient id to rows that share the same age,
# gender, weight and height.
# patient_id = (real_data.groupby(['age', 'gender', 'weight', 'height'])
#               .ngroup() + 1)
# real_data.insert(0, 'patient_id', patient_id)


# Display the dataset
st.header("Real Data")
show_df(real_data)

# Let SDV infer the column types from the loaded DataFrame. The same
# `metadata` object is reused below for synthesis and for the evaluation
# reports.
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(real_data)

# st.json(metadata.to_dict())
# Write the metadata diagram to a PNG file, then display that file.
# NOTE(review): assumes the ./assets directory exists and is writable — confirm.
metadata.visualize(output_filepath='./assets/my_metadata.png')

st.image('./assets/my_metadata.png', width=300)

# Input field for the user to specify the number of rows for synthetic data generation
number = st.number_input('Number of rows', min_value=0, step=1000)

# A value of 0 ends the script run here, so nothing below executes until the
# user enters a positive row count. Only the Gaussian Copula path passes
# `number` on to the generator.
if number == 0:
    st.stop()


@st.cache_resource()
def load_synthesizer(_metadata):
    """Build the (not yet fitted) Gaussian Copula synthesizer.

    Args:
        _metadata: SDV metadata object describing the table to model. The
            leading underscore tells Streamlit not to hash this argument, so
            it does not take part in the cache key.

    Returns:
        A ``GaussianCopulaSynthesizer`` created with
        ``enforce_min_max_values=True``, ``enforce_rounding=False`` and
        ``default_distribution='norm'``.
    """
    # Alternatives that were tried here before and are currently disabled:
    # SingleTablePreset (name='FAST_ML'), PARSynthesizer with 'patient_id' as
    # sequence key, and CTGANSynthesizer.

    # Build from the argument that was passed in. The previous code read the
    # module-level `metadata` instead, which ignored the parameter and only
    # worked because a global of that name happened to exist.
    # NOTE(review): since `_metadata` is excluded from the cache key, the
    # cached synthesizer is reused for every metadata object until the
    # resource cache is cleared — confirm this is intended across experiments.
    return GaussianCopulaSynthesizer(
        _metadata,  # required
        enforce_min_max_values=True,
        enforce_rounding=False,
        default_distribution='norm',
    )


@st.cache_data()
def get_synthetic_data(_synthesizer, real_data, number):
    """Fit the synthesizer on the real table and draw synthetic rows.

    Args:
        _synthesizer: Object exposing ``fit(data=...)`` and
            ``sample(num_rows=...)``. The leading underscore keeps Streamlit
            from hashing it for the cache key.
        real_data: Table the synthesizer is fitted on.
        number: How many rows to sample.

    Returns:
        Whatever ``_synthesizer.sample`` returns for ``number`` rows.
    """
    # Fitting happens inside the cached call, so it is only repeated when
    # `real_data` or `number` changes.
    _synthesizer.fit(data=real_data)
    return _synthesizer.sample(num_rows=number)


@st.cache_data()
def _cached_smote_resample(real_data, target_column):
    """Run SMOTE on ``real_data``; both arguments are part of the cache key."""
    features = real_data.drop(target_column, axis=1)
    target = real_data[target_column]

    # Fixed seed so repeated runs on the same input give the same rows.
    features_res, target_res = SMOTE(random_state=42).fit_resample(
        features, target)

    # Feature columns first, the target column appended last.
    return pd.concat([features_res, target_res], axis=1)


def get_smote_data(_real_data, _target_column):
    """Return ``_real_data`` oversampled with SMOTE on ``_target_column``.

    Args:
        _real_data: DataFrame containing the features and the target column.
        _target_column: Label of the column SMOTE treats as the class target.

    Returns:
        A DataFrame with the resampled feature columns followed by the
        resampled target column.
    """
    # Streamlit leaves underscore-prefixed parameters out of the cache key.
    # With the cache decorator applied directly to this function, neither
    # argument was hashed, so the first result was returned for every later
    # call, even after the user picked another target column or dataset.
    # Delegating to a cached helper whose parameters ARE hashed keeps this
    # function's signature intact while making the cache depend on its inputs.
    return _cached_smote_resample(_real_data, _target_column)


@st.cache_data()
def get_smogn_data(real_data, target_column):
    """Return the result of ``smogn.smoter`` for the given table and target.

    Args:
        real_data: Table passed to ``smogn.smoter`` as ``data``.
        target_column: Column label passed to ``smogn.smoter`` as ``y``.

    Returns:
        The value returned by ``smogn.smoter``.
    """
    return smogn.smoter(data=real_data, y=target_column)


# Let the user choose how the synthetic table is produced. Every branch
# below assigns `synthetic_data`.
synth_opts = ["Gaussian Copula", "SMOTE", "SMOGN"]
synth_opt = st.selectbox("Choose a synthesizer", synth_opts)

if synth_opt == "Gaussian Copula":
    synthesizer = load_synthesizer(metadata)
    synthetic_data = get_synthetic_data(synthesizer, real_data, number)
elif synth_opt in ("SMOTE", "SMOGN"):
    # Both resamplers work against a target column; the last column is
    # preselected.
    target_column = st.selectbox(
        'Choose the Target Column',
        real_data.columns,
        key="preserve_chosen_target",
        index=len(real_data.columns) - 1,
        on_change=reset_run_btn,
    )

    synthetic_data = (
        get_smote_data(real_data, target_column)
        if synth_opt == "SMOTE"
        else get_smogn_data(real_data, target_column)
    )

st.header("Synthetic Data")
show_df(synthetic_data)

# 1. Diagnostic report: show its properties table and the 'Data Validity' chart
diagnostic = run_diagnostic(real_data, synthetic_data, metadata)
diag_props = diagnostic.get_properties()
show_df(diag_props)
st.plotly_chart(diagnostic.get_visualization('Data Validity'))

# 2. Quality report: show its properties table and one chart per listed property
quality_report = evaluate_quality(real_data, synthetic_data, metadata)
quality_props = quality_report.get_properties()
show_df(quality_props)
for property_name in quality_props['Property']:
    st.plotly_chart(quality_report.get_visualization(property_name))

# 3. One tab per column, each holding the real-vs-synthetic plot of that column
tabs = st.tabs(list(real_data.columns))
for tab, column in zip(tabs, real_data.columns):
    fig = get_column_plot(
        real_data=real_data,
        synthetic_data=synthetic_data,
        metadata=metadata,
        column_name=column,
    )
    tab.plotly_chart(fig)

# Write the synthetic table to disk on request, then refresh the file tree
if st.button("Save Synthetic Data"):
    synthetic_data.to_csv(PATH_TO_GEN_DATASET, index=None)
    st.success("Synthetic data saved successfully")
    update_file_tree(file_tree_ph, globals())