###################
### Description ###
###################
# This script is for preprocessing the GenScot data (already in methylation beta values in AnnData format, with
# CpGs as obs and samples as vars). Specifically, it calculates various CpG-level statistics (e.g. Spearman r)
# and adds them to the AnnData object.
# Inputs: Anndata object of GenScot methylation beta values (wave 3), with CpGs as obs and samples as vars.
# Outputs: Anndata object with CpG-level statistics added to obs.
##############
### Author ###
##############
# Sam Crofts (sam.crofts@ed.ac.uk)
##################################
##################################
#Imports
import sys
sys.path.append("../..") # fix to import modules from root
from src.general_imports import *
import statsmodels.api as sm
import statsmodels.formula.api as smf
from scipy.signal import find_peaks
from scipy.stats import spearmanr
# Load the Wave 3 GenScot AnnData file. It is first opened with backed='r'
# and then materialised fully in memory via to_memory().
wave3_h5ad_path = '/exports/igmm/eddie/tchandra-lab/EJYang/methylclock/GS_mval/wave3_meta.h5ad'
all_data = ad.read_h5ad(wave3_h5ad_path, backed='r')
amdata = all_data.to_memory()

# Median of the per-sample ages stored in var; used further down to split
# the samples into a "below median" and an "at or above median" half.
sample_ages = amdata.var['age'].values
median_age = np.median(sample_ages)
# Initialise the per-CpG statistic columns on obs.
#
# The statistics written into these columns later are floats, so they are
# initialised with 0.0 rather than the integer 0: an int64 column that later
# receives float values cell by cell relies on silent upcasting, which recent
# pandas versions deprecate. 'n_peaks' is a count and therefore stays integer.
# 'intercept' is also written per site further down, so it is created here
# (last, so that the final column order is unchanged) instead of appearing
# implicitly on the first assignment.
initial_column_values = {
    'r2': 0.0,
    'beta': 0.0,
    'bp_pval': 0.0,
    'spearman_stat': 0.0,
    'spearman_pval': 0.0,
    'beta_pval': 0.0,
    'variance_overall': 0.0,
    'variance_first_half': 0.0,
    'variance_second_half': 0.0,
    'mean_meth': 0.0,
    'var_change': 0.0,
    'var_change_perc': 0.0,
    'white_pval': 0.0,
    'n_peaks': 0,
    'slope_below_median': 0.0,
    'slope_above_median': 0.0,
    'intercept': 0.0,
}
for column_name, initial_value in initial_column_values.items():
    amdata.obs[column_name] = initial_value
# Loop through sites (rows of obs) and compute, for each one:
#   - OLS fit of methylation on age (slope, intercept, adjusted R^2, slope p-value)
#   - Breusch-Pagan and White heteroscedasticity test p-values on the OLS residuals
#   - Spearman correlation between age and methylation
#   - overall variance, variance below / at-or-above the median age, and their difference
#   - number of peaks in a Gaussian KDE of the methylation values
#   - OLS slope fitted separately below and at-or-above the median age
# Results are collected in preallocated arrays and written to amdata.obs once
# at the end, instead of one scalar .loc assignment per site and column.

# Quantities that do not depend on the site are computed once.
age = amdata.var.age.values
below_median_mask = age < median_age
above_median_mask = age >= median_age
age_below = age[below_median_mask]
age_above = age[above_median_mask]

# Evaluation grid for the KDE. It spans [-0.1, 1.1], i.e. slightly wider than
# the unit interval.
# NOTE(review): this presumes the values are beta values within [0, 1] - confirm.
kde_grid = np.linspace(-0.1, 1.1, 100)

# 'intercept' is listed last so that, if it does not exist on obs yet, it is
# appended after the existing columns.
float_stat_columns = [
    'r2', 'beta', 'bp_pval', 'spearman_stat', 'spearman_pval', 'beta_pval',
    'variance_overall', 'variance_first_half', 'variance_second_half',
    'mean_meth', 'var_change', 'var_change_perc', 'white_pval',
    'slope_below_median', 'slope_above_median', 'intercept',
]
n_sites = len(amdata.obs.index)
site_stats = {col: np.full(n_sites, np.nan) for col in float_stat_columns}
n_peaks = np.zeros(n_sites, dtype=int)

for i, site in enumerate(tqdm(amdata.obs.index)):
    # Slice the AnnData object once per site and work on a flat vector of
    # per-sample methylation values from here on.
    meth = amdata[site].X.flatten()

    # Fit linear model of methylation on age over all samples
    df = pd.DataFrame({'age': age, 'meth': meth})
    model = smf.ols('meth ~ age', data=df).fit()
    residuals = model.resid
    site_stats['beta'][i] = model.params["age"]
    site_stats['r2'][i] = model.rsquared_adj
    site_stats['beta_pval'][i] = model.pvalues["age"]
    site_stats['intercept'][i] = model.params["Intercept"]
    site_stats['mean_meth'][i] = np.mean(meth)

    # Breusch-Pagan and White tests for heteroscedasticity; only the second
    # returned value (the LM-test p-value) is kept from each.
    _, p_value_bp, _, _ = sm.stats.diagnostic.het_breuschpagan(residuals, model.model.exog)
    _, p_value_white, _, _ = sm.stats.diagnostic.het_white(residuals, model.model.exog)
    site_stats['bp_pval'][i] = p_value_bp
    site_stats['white_pval'][i] = p_value_white

    # Variances overall and for each half of the age split
    meth_below = meth[below_median_mask]
    meth_above = meth[above_median_mask]
    variance_first_half = np.var(meth_below)
    variance_second_half = np.var(meth_above)
    variance_overall = np.var(meth)
    var_change = variance_second_half - variance_first_half
    site_stats['variance_overall'][i] = variance_overall
    site_stats['variance_first_half'][i] = variance_first_half
    site_stats['variance_second_half'][i] = variance_second_half
    site_stats['var_change'][i] = var_change
    # Change in variance expressed as a percentage of the overall variance.
    site_stats['var_change_perc'][i] = (var_change / variance_overall) * 100

    # Spearman correlation (computed once; statistic and p-value unpacked together)
    spearman_stat, spearman_pval = spearmanr(age, meth)
    site_stats['spearman_stat'][i] = spearman_stat
    site_stats['spearman_pval'][i] = spearman_pval

    # Count the peaks of a Gaussian KDE of the site's values. This only
    # records the count in 'n_peaks'; nothing is filtered out here.
    dens = sm.nonparametric.KDEUnivariate(meth)
    # Bandwidth is a function of the SD of the data: 5% of 4 SDs.
    sd = np.sqrt(variance_overall)
    ninety_five_pc_of_data_range = 4 * sd
    bw = 0.05 * ninety_five_pc_of_data_range
    # A zero bandwidth (constant data) is replaced by a fixed fallback so the fit does not fail.
    if bw == 0:
        bw = 0.1
    dens.fit(kernel='gau', bw=bw)
    density = dens.evaluate(kde_grid)
    # Rescale so the highest point is 1, making the prominence threshold relative.
    density = density / max(density)
    n_peaks[i] = find_peaks(density, prominence=0.001)[0].shape[0]

    # Slope of methylation on age, fitted separately on samples below the
    # median age and on samples at or above it.
    df_below = pd.DataFrame({'age': age_below, 'meth': meth_below})
    model_below = smf.ols('meth ~ age', data=df_below).fit()
    site_stats['slope_below_median'][i] = model_below.params["age"]

    df_above = pd.DataFrame({'age': age_above, 'meth': meth_above})
    model_above = smf.ols('meth ~ age', data=df_above).fit()
    site_stats['slope_above_median'][i] = model_above.params["age"]

# Add everything to the anndata object in one go. The arrays are in the same
# order as amdata.obs.index, so they are assigned positionally.
for col in float_stat_columns:
    amdata.obs[col] = site_stats[col]
amdata.obs['n_peaks'] = n_peaks
# Persist the annotated AnnData object (obs now carries the per-site
# statistics) as a pickle file.
output_path = '../data/genscot_full_with_site_details.pk'
with open(output_path, 'wb') as out_file:
    pickle.dump(amdata, out_file)