import streamlit as st
import pandas as pd
from sdv.metadata import SingleTableMetadata
from sdv.lite import SingleTablePreset
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality
from sdv.evaluation.single_table import get_column_plot
from tools import *
from sdv.single_table import CTGANSynthesizer, GaussianCopulaSynthesizer
from sdv.sequential import PARSynthesizer
from imblearn.over_sampling import SMOTE
import smogn
# Draw the shared sidebar; the returned placeholder is reused at the bottom of
# the page to refresh the file tree after saving.
_, file_tree_ph = sidebar(globals())
st.title("Generate Synthetic Data")

# Nothing on this page makes sense without a selected experiment.
exp_id = st.session_state.get('exp_id')
if exp_id is None:
    st.write("Please select an experiment first.")
    st.stop()

# Load the training dataset that the synthesizers will learn from.
real_data = pd.read_csv(PATH_TO_TRAIN_DATASET)

# NOTE: a previous experiment derived a 'patient_id' column by grouping rows on
# (age, gender, weight, height) and inserting the group number as the first
# column; it is currently disabled.

st.header("Real Data")
show_df(real_data)

# Infer the column types from the data and show the detected schema as an
# image (rendered to disk first, then displayed).
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(real_data)
metadata.visualize(output_filepath='./assets/my_metadata.png')
st.image('./assets/my_metadata.png', width=300)

# Ask how many synthetic rows to generate; the page halts here until the user
# enters a non-zero value.
number = st.number_input('Number of rows', min_value=0, step=1000)
if not number:
    st.stop()
@st.cache_resource()
def load_synthesizer(_metadata):
    """Create the Gaussian Copula synthesizer for a table.

    Args:
        _metadata: Metadata object describing the table to model. The leading
            underscore tells Streamlit not to hash this argument when
            computing the cache key.

    Returns:
        An unfitted ``GaussianCopulaSynthesizer`` configured to enforce the
        observed min/max values, keep full numeric precision (no rounding)
        and use a normal distribution for columns by default.

    NOTE(review): because the only argument is excluded from the cache key,
    the same cached synthesizer is returned for every call, even when the
    metadata differs. Confirm that this is acceptable when switching between
    datasets.
    """
    # Alternatives that were tried previously and are currently disabled:
    # SingleTablePreset(name='FAST_ML'), PARSynthesizer (with 'patient_id' as
    # an id column / sequence key) and CTGANSynthesizer. Per-column
    # distributions can be supplied through `numerical_distributions`.
    return GaussianCopulaSynthesizer(
        _metadata,  # required; use the argument, not the module-level global
        enforce_min_max_values=True,
        enforce_rounding=False,
        default_distribution='norm',
    )
@st.cache_data()
def get_synthetic_data(_synthesizer, real_data, number):
    """Fit the synthesizer on the real data and draw synthetic rows.

    Args:
        _synthesizer: Object exposing ``fit`` and ``sample``; the leading
            underscore excludes it from Streamlit's cache key.
        real_data: Data passed to ``fit``.
        number: Number of rows requested from ``sample``.

    Returns:
        Whatever ``_synthesizer.sample`` returns for ``num_rows=number``.
    """
    # Fitting mutates the synthesizer in place; sampling then uses that fit.
    _synthesizer.fit(data=real_data)
    return _synthesizer.sample(num_rows=number)
@st.cache_data()
def _cached_smote_resample(real_data, target_column):
    """Run SMOTE on ``real_data`` and return features and target re-joined.

    The parameters deliberately have no leading underscore so that Streamlit
    hashes both of them: a different dataset or target column produces a
    different cache entry.
    """
    features = real_data.drop(target_column, axis=1)
    labels = real_data[target_column]
    # Fixed seed keeps the resampling reproducible between reruns.
    X_res, y_res = SMOTE(random_state=42).fit_resample(features, labels)
    # The target column is appended after the feature columns.
    return pd.concat([X_res, y_res], axis=1)


def get_smote_data(_real_data, _target_column):
    """Oversample the dataset with SMOTE, using ``_target_column`` as label.

    Args:
        _real_data: DataFrame holding the features and the target column.
        _target_column: Name of the column to use as the class label.

    Returns:
        DataFrame with the resampled feature columns followed by the
        resampled target column.

    The public parameter names are kept for backward compatibility, but
    underscore-prefixed parameters are skipped by ``st.cache_data`` when
    building the cache key. Caching is therefore done by a helper whose
    parameters are hashed; otherwise the first result would be returned for
    every later dataset / target column.
    """
    return _cached_smote_resample(_real_data, _target_column)


# Keep the cache-clearing hook callers may have used on the cached function.
get_smote_data.clear = _cached_smote_resample.clear
@st.cache_data()
def get_smogn_data(real_data, target_column):
    """Resample the dataset with ``smogn.smoter``.

    Args:
        real_data: Data handed to ``smogn.smoter`` as ``data``.
        target_column: Column name handed to ``smogn.smoter`` as ``y``.

    Returns:
        The result of ``smogn.smoter``; cached by Streamlit on both arguments.
    """
    return smogn.smoter(data=real_data, y=target_column)
# Let the user pick the generation method, then produce `synthetic_data`.
synth_opts = ["Gaussian Copula", "SMOTE", "SMOGN"]
synth_opt = st.selectbox("Choose a synthesizer", synth_opts)

if synth_opt == "Gaussian Copula":
    synthesizer = load_synthesizer(metadata)
    synthetic_data = get_synthetic_data(synthesizer, real_data, number)
elif synth_opt in ("SMOTE", "SMOGN"):
    # Both resampling methods need a target column; default to the last one.
    target_column = st.selectbox(
        'Choose the Target Column',
        real_data.columns,
        key="preserve_chosen_target",
        index=len(real_data.columns) - 1,
        on_change=reset_run_btn,
    )
    synthetic_data = (
        get_smote_data(real_data, target_column)
        if synth_opt == "SMOTE"
        else get_smogn_data(real_data, target_column)
    )
st.header("Synthetic Data")
show_df(synthetic_data)

# Step 1: diagnostic report (basic validity checks).
diagnostic = run_diagnostic(real_data, synthetic_data, metadata)
diag_props = diagnostic.get_properties()
show_df(diag_props)
st.plotly_chart(diagnostic.get_visualization('Data Validity'))

# Step 2: quality report (statistical similarity), one chart per property.
quality_report = evaluate_quality(real_data, synthetic_data, metadata)
quality_props = quality_report.get_properties()
show_df(quality_props)
for p in quality_props['Property']:
    st.plotly_chart(quality_report.get_visualization(p))

# Step 3: one tab per column comparing real and synthetic values.
tabs = st.tabs(list(real_data.columns))
for tab, col_name in zip(tabs, real_data.columns):
    fig = get_column_plot(
        real_data=real_data,
        synthetic_data=synthetic_data,
        metadata=metadata,
        column_name=col_name,
    )
    tab.plotly_chart(fig)

# Step 4: persist the generated data on request and refresh the file tree.
if st.button("Save Synthetic Data"):
    synthetic_data.to_csv(PATH_TO_GEN_DATASET, index=None)
    st.success("Synthetic data saved successfully")
    update_file_tree(file_tree_ph, globals())