"""Streamlit page: generate, evaluate and save synthetic tabular data.

Page flow:
1. Load the training dataset of the currently selected experiment.
2. Detect SDV metadata for it and display the metadata diagram.
3. Produce synthetic data with Gaussian Copula, SMOTE or SMOGN.
4. Show SDV diagnostic and quality reports plus per-column plots.
5. Optionally save the synthetic data to disk.
"""
import streamlit as st
import pandas as pd
from sdv.metadata import SingleTableMetadata
from sdv.lite import SingleTablePreset
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality
from sdv.evaluation.single_table import get_column_plot
# NOTE: the star import is deliberately kept at this position; the imports
# below must stay after it so they win if `tools` exports the same names.
from tools import *
from sdv.single_table import CTGANSynthesizer, GaussianCopulaSynthesizer
from sdv.sequential import PARSynthesizer
from imblearn.over_sampling import SMOTE
import smogn

_, file_tree_ph = sidebar(globals())

st.title("Generate Synthetic Data")

exp_id = st.session_state.get('exp_id')
if exp_id is None:
    st.write("Please select an experiment first.")
    st.stop()

# Read the uploaded dataset into a DataFrame
real_data = pd.read_csv(PATH_TO_TRAIN_DATASET)

# Idea kept for reference (not active): derive a `patient_id` column by
# grouping rows on ['age', 'gender', 'weight', 'height'] (ngroup() + 1) and
# inserting it as the first column, so a sequential synthesizer can key on it.

# Display the dataset
st.header("Real Data")
show_df(real_data)

metadata = SingleTableMetadata()
metadata.detect_from_dataframe(real_data)
metadata.visualize(output_filepath='./assets/my_metadata.png')
st.image('./assets/my_metadata.png', width=300)

# Input field for the user to specify the number of rows for synthetic data
# generation. NOTE(review): the page stops here while the value is 0, even
# though SMOTE/SMOGN below do not use `number` - confirm this is intended.
number = st.number_input('Number of rows', min_value=0, step=1000)
if number == 0:
    st.stop()


@st.cache_resource()
def load_synthesizer(_metadata, cache_key=None):
    """Build the (unfitted) Gaussian Copula synthesizer for `_metadata`.

    Args:
        _metadata: SDV table metadata describing the real data. The leading
            underscore tells Streamlit not to hash this argument.
        cache_key: Any Streamlit-hashable value identifying the metadata
            (e.g. ``metadata.to_dict()``). Because `_metadata` is excluded
            from hashing, this is the only thing that distinguishes cache
            entries; without it every call returns the first synthesizer
            ever created, even after the dataset changes.

    Returns:
        A `GaussianCopulaSynthesizer` configured for `_metadata`.
    """
    # Alternatives that were tried here previously: SingleTablePreset
    # (name='FAST_ML'), PARSynthesizer (with `patient_id` as sequence key)
    # and CTGANSynthesizer. Per-column distributions can be supplied through
    # the `numerical_distributions` argument if 'norm' is not a good fit.
    return GaussianCopulaSynthesizer(
        _metadata,  # use the argument, not the module-level global
        enforce_min_max_values=True,
        enforce_rounding=False,
        default_distribution='norm',
    )


@st.cache_data()
def get_synthetic_data(_synthesizer, real_data, number):
    """Fit `_synthesizer` on `real_data` and sample `number` synthetic rows.

    Args:
        _synthesizer: SDV synthesizer; not hashed by Streamlit (underscore).
        real_data: DataFrame the synthesizer is fitted on (part of cache key).
        number: Number of rows to sample (part of cache key).

    Returns:
        DataFrame with `number` sampled rows.
    """
    _synthesizer.fit(data=real_data)
    return _synthesizer.sample(num_rows=number)


@st.cache_data()
def get_smote_data(real_data, target_column):
    """Resample `real_data` with SMOTE using `target_column` as the label.

    Both arguments are hashed by Streamlit, so a different dataset or target
    column yields a fresh result instead of a stale cached one.

    Args:
        real_data: Input DataFrame.
        target_column: Name of the column used as the class label.

    Returns:
        Resampled DataFrame with the same column order as `real_data`.
    """
    sm = SMOTE(random_state=42)
    features = real_data.drop(target_column, axis=1)
    labels = real_data[target_column]
    features_res, labels_res = sm.fit_resample(features, labels)
    resampled = pd.concat([features_res, labels_res], axis=1)
    # concat puts the target last; restore the original column order.
    return resampled[real_data.columns]


@st.cache_data()
def get_smogn_data(real_data, target_column):
    """Resample `real_data` with SMOGN using `target_column` as the target.

    Args:
        real_data: Input DataFrame.
        target_column: Name of the target column passed to `smogn.smoter`.

    Returns:
        The DataFrame returned by `smogn.smoter`.
    """
    return smogn.smoter(data=real_data, y=target_column)


synth_opts = ["Gaussian Copula", "SMOTE", "SMOGN"]
synth_opt = st.selectbox("Choose a synthesizer", synth_opts)

if synth_opt == "Gaussian Copula":
    # The metadata dict is passed as the hashable cache key so a new dataset
    # gets a new synthesizer.
    synthesizer = load_synthesizer(metadata, cache_key=metadata.to_dict())
    synthetic_data = get_synthetic_data(synthesizer, real_data, number)
elif synth_opt in ("SMOTE", "SMOGN"):
    target_column = st.selectbox(
        'Choose the Target Column',
        real_data.columns,
        key="preserve_chosen_target",
        index=len(real_data.columns) - 1,
        on_change=reset_run_btn)
    if synth_opt == "SMOTE":
        synthetic_data = get_smote_data(real_data, target_column)
    else:
        synthetic_data = get_smogn_data(real_data, target_column)

st.header("Synthetic Data")
show_df(synthetic_data)

# 1. perform basic validity checks
diagnostic = run_diagnostic(real_data, synthetic_data, metadata)
diag_props = diagnostic.get_properties()
show_df(diag_props)
st.plotly_chart(diagnostic.get_visualization('Data Validity'))

# 2. measure the statistical similarity
quality_report = evaluate_quality(real_data, synthetic_data, metadata)
quality_props = quality_report.get_properties()
show_df(quality_props)
for property_name in quality_props['Property']:
    st.plotly_chart(quality_report.get_visualization(property_name))

# 3. plot the data in streamlit tabs (one tab per column)
column_names = list(real_data.columns)
tabs = st.tabs(column_names)
for tab, column_name in zip(tabs, column_names):
    fig = get_column_plot(
        real_data=real_data,
        synthetic_data=synthetic_data,
        metadata=metadata,
        column_name=column_name,
    )
    tab.plotly_chart(fig)

# save the synthetic data
if st.button("Save Synthetic Data"):
    synthetic_data.to_csv(PATH_TO_GEN_DATASET, index=False)
    st.success("Synthetic data saved successfully")
    update_file_tree(file_tree_ph, globals())