# SCARLET / src / general_imports.py
###################
### Description ###
###################
 
# This script is for general imports that are used in multiple scripts.
# Includes general helper functions, model mean and variance definitions, 
# and model creation functions with PyMC.
 
##############
### Author ###
##############

# Sam Crofts (sam.crofts@ed.ac.uk)
 
########################################################################################################################
 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from multiprocessing import Pool
from multiprocessing import cpu_count
import anndata as ad
import json
import pickle
import math
import random
import arviz as az
from scipy.stats import bootstrap
import pymc as pm
from sklearn.feature_selection import r_regression
from scipy.stats import bootstrap
from itertools import product
from functools import partial
from scipy.stats import norm
from scipy.stats import beta
from scipy.stats import ks_2samp
from scipy.stats import spearmanr
import contextlib
import os

##########################
#General helper functions#
##########################

# Function to uniformly sample participants from the data with regard to age 
# (as much as possible given the underlying distribution of ages in the data)
def sample_to_uniform_age(amdata, n_part):
    """
    Draw ``n_part`` participants whose ages are spread roughly uniformly.

    For each draw, a target age is sampled uniformly between the youngest and
    oldest ages in ``amdata.var.age``. The not-yet-chosen participant closest
    to that target is taken, without replacement. How uniform the result really
    is depends on the ages available in the data.

    Returns ``amdata`` restricted to the chosen participants (columns), in the
    order they were drawn.
    """
    # The age range never changes between draws, so read it once.
    youngest = amdata.var.age.min()
    oldest = amdata.var.age.max()

    pool = list(amdata.var.index)
    chosen = []
    for _ in range(n_part):
        target = np.random.uniform(youngest, oldest)
        pool_ages = amdata[:, pool].var.age
        # argmin is positional, so it indexes directly into ``pool``.
        pos = np.abs(pool_ages - target).argmin()
        chosen.append(pool.pop(pos))
    return amdata[:, chosen]

# Function to individually match samples from one dataset to another based on age 
def match_ages(amdata_to_sample_from, amdata_to_match_on):
    """
    For each participant of ``amdata_to_match_on``, pick the closest-in-age
    participant of ``amdata_to_sample_from`` that has not been picked yet.

    Targets are processed in column order and matching is greedy and without
    replacement, so earlier targets get first choice. On age ties, the
    candidate that comes first in ``amdata_to_sample_from`` wins.

    Returns ``amdata_to_sample_from`` restricted to the picked participants,
    one per target, in target order.
    """
    available = list(amdata_to_sample_from.var.index)
    picked = []
    for target_age in amdata_to_match_on.var.age:
        available_ages = amdata_to_sample_from[:, available].var.age
        best_pos = np.abs(available_ages - target_age).argmin()
        picked.append(available.pop(best_pos))
    return amdata_to_sample_from[:, picked]

# Function to individually match samples from one dataset to another based on both age and sex
def match_ages_and_sex(amdata_to_sample_from, amdata_to_match_on):
    """
    Greedily match every participant of ``amdata_to_match_on`` to the
    closest-in-age, same-sex participant of ``amdata_to_sample_from``.

    Matching is without replacement: each participant of
    ``amdata_to_sample_from`` is used at most once. Targets are processed in
    column order, so earlier targets get first choice. On age ties, the
    candidate that comes first in ``amdata_to_sample_from`` wins.

    Returns
    -------
    ``amdata_to_sample_from`` restricted to the matched participants, one per
    target, in target order.

    Raises
    ------
    ValueError
        If no unused participant of the required sex is left for some target.
    """
    pool_var = amdata_to_sample_from.var
    target_var = amdata_to_match_on.var
    remaining_idx = list(pool_var.index)
    used_idx = []
    for target_age, target_sex in zip(target_var.age, target_var.sex):
        # Unused candidates of the right sex, kept in the pool's original
        # order so ties are broken the same way as before.
        candidates = pool_var.loc[remaining_idx]
        candidates = candidates[candidates.sex == target_sex]
        if candidates.empty:
            raise ValueError(
                f"No unused participant with sex {target_sex!r} left to match "
                f"a target of age {target_age}")
        best = candidates.index[np.abs(candidates.age - target_age).argmin()]
        used_idx.append(best)
        # Remove by label: the matched participant's position among the
        # same-sex candidates is not its position in ``remaining_idx``.
        remaining_idx.remove(best)
    return amdata_to_sample_from[:, used_idx]

# Function to ensure that the mean and variance are valid for a Beta distribution
def ensure_valid_beta_mu_var(mu, var, *, mu_bounds=(0.001, 0.999),
                             min_var=1e-10, max_var_fraction=0.99):
    """
    Clip a Beta mean and variance into the region where the Beta is defined.

    A Beta distribution parameterised by mean and variance needs
    0 < mu < 1 and 0 < var < mu * (1 - mu). The mean is clipped first. The
    variance upper bound is then computed from the *clipped* mean, so
    repeated calls with the returned mean give the same result.

    Parameters:
    -----------
    mu : array-like or PyMC/PyTensor expression
        Mean values to be clipped.
    var : array-like or PyMC/PyTensor expression
        Variance values to be clipped.
    mu_bounds : (float, float), keyword-only
        Lower/upper clip for the mean. Default (0.001, 0.999).
    min_var : float, keyword-only
        Lower clip for the variance. Default 1e-10. Keep it below the smallest
        possible upper bound (with the defaults about 9.9e-4).
    max_var_fraction : float, keyword-only
        Upper clip for the variance as a fraction of mu_clipped * (1 - mu_clipped).
        Default 0.99 (99% of the theoretical maximum).

    Returns:
    --------
    mu_clipped : same kind as ``mu``
        Mean clipped to ``mu_bounds`` (default [0.001, 0.999]).
    var_clipped : same kind as ``var``
        Variance clipped to [min_var, max_var_fraction * mu_clipped * (1 - mu_clipped)]
        (default [1e-10, 0.99 * mu_clipped * (1 - mu_clipped)]).
    """
    lower_mu, upper_mu = mu_bounds

    # Clip mean to be within (0, 1)
    mu_clipped = pm.math.clip(mu, lower_mu, upper_mu)

    # Largest variance allowed for a Beta with this (clipped) mean
    max_variance = mu_clipped * (1 - mu_clipped) * max_var_fraction

    # Clip variance to be within (0, max_variance)
    var_clipped = pm.math.clip(var, min_var, max_variance)

    return mu_clipped, var_clipped


#####################################
#Model mean and variance definitions#
#####################################

###Model of methylation changes via stem cell divisions. See mathematical details in the supplementary material###

##Parameters##

# N = number of stem cells
# M = number of methylated stem cells
# t = time (years)
# s = division rate per stem cell per year
# n = eta, theoretical final methylation level at t=infinity, equal to (u)/(m+u)
# m = P_(M->U), probability of a methylated site becoming unmethylated per division
# u = P_(U->M), probability of an unmethylated site becoming methylated per division
# w = omega, equal to m+u
# Z = total number of methylated cells relative to N (i.e. Z(t) = M(t)/N(t))
# p = theoretical initial methylation level at t=0
# c = initial variance of methylated stem cells, M, at t=0

def mean_Z(t, s, n, w, p):
    """
    Expected methylated fraction Z(t) = M(t)/N(t).

    Z starts at the initial level ``p`` and relaxes exponentially towards the
    equilibrium level ``n`` (eta) at rate 2*s*w. Arguments broadcast, so
    ``t`` may be a column of ages and ``n``/``w``/``p`` per-site vectors.
    """
    exponent = -2*s*t*w
    decay = np.exp(exponent)
    return n + decay*(p-n)

def cov_NM(t, s, N, n, w, p):
    """
    Covariance between the stem-cell count N and methylated count M at time t.

    Equal to 2*s*N*t * (n + exp(-2*s*t*w) * (1-w) * (p-n)).
    """
    exponent = -2*s*t*w
    bracket = n + np.exp(exponent)*(1-w)*(p-n)
    return 2*s*N*t*bracket

def var_N(t, s, N):
    """
    Variance of the stem-cell count N at time t.

    Grows linearly in time: 2*s*N*t (s = divisions per stem cell per year).
    """
    variance = 2*s*N*t
    return variance

def var_M(t, s, N, n, w, p, c):
    """
    Variance of the number of methylated stem cells, M(t).

    Closed-form expression (derivation in the supplementary material):

        Var[M](t) = (A0 + A2*exp(-2*s*t*w) + A4*exp(-4*s*t*w)) / (2*w)

    A0 has no exponential factor and contains a term linear in t
    (the 4*s*w*t term). A2 and A4 multiply the two decaying exponentials.
    The initial variance ``c`` appears only in A4.

    Parameters: t time (years), s division rate per stem cell per year,
    N number of stem cells, n = eta (final methylation level),
    w = omega = m + u, p initial methylation level, c initial variance of M.

    Only NumPy ufuncs and arithmetic are used, so array arguments broadcast.
    NOTE(review): divides by 2*w, so w must be non-zero.
    """

    # Term without an exponential factor (includes the part linear in t)
    A0 = N*n*(1
              +np.power(w, 2)
              -n*(1
                  + np.power(w, 2)
                  - 4*s*w*t)
              )

    # Coefficient of exp(-2*s*t*w)
    A2 = 2*N*(n-p)*(-1
                    + w
                    - np.power(w, 2)
                    + 2*n*(1
                          - w*(1+2*s*t)
                          + np.power(w,2)*(1+2*s*t)))

    # Coefficient of exp(-4*s*t*w); the only place the initial variance c enters
    A4 = (np.power(n,2)*N*(-3 + 4*w - 3*np.power(w,2))
          + 2*(c*w + N*p*(-1 + w - np.power(w, 2)))
          + n*N*(1+4*p - 2*w*(1+2*p)+np.power(w,2)*(1+4*p))          
    )

    vM = (A0 + A2*np.exp(-2*s*t*w) + A4*np.exp(-4*s*t*w))/(2*w)

    return vM

def var_Z_order_1(t, s, N, n, w, p, c):
    """
    First-order (delta-method) variance of the methylated fraction Z = M/N.

    Combines the moments of N and M:

        Var[Z] ~= (E[Z]^2 * Var[N] + Var[M] - 2 * E[Z] * Cov[N, M]) / N^2
    """
    z_mean = mean_Z(t, s, n, w, p)
    n_var = var_N(t, s, N)
    m_var = var_M(t, s, N, n, w, p, c)
    nm_cov = cov_NM(t, s, N, n, w, p)

    numerator = np.power(z_mean, 2)*n_var + m_var - 2*z_mean*nm_cov
    return numerator/(np.power(N, 2))

# Function to get prior parameters from data and prep data for modelling
def get_prior_params_and_prep(amdata_load, early_quantile=0.05):
    """
    Shift ages so the youngest participant is 0 and store per-site prior stats.

    The youngest participants (age strictly below the ``early_quantile`` age
    quantile) provide, for every site, the mean (``obs['p_mean']``) and the
    standard deviation (``obs['early_sd']``) of the observed values. The model
    builders use these to centre the priors on the initial level p and the
    initial standard deviation.

    If ties at the minimum age leave nobody strictly below the quantile, the
    participants at the quantile (``<=``) are used instead, so the priors are
    never computed from an empty group (which would give NaN).

    Note: ``amdata_load`` is modified in place (``var['age']`` and the two
    ``obs`` columns), and the same object is returned.

    Parameters
    ----------
    amdata_load : AnnData
        Sites in rows (obs), participants in columns (var), with a
        ``var['age']`` column.
    early_quantile : float, optional
        Age quantile that defines the "early" group. Default 0.05.

    Returns
    -------
    The same ``amdata_load`` object.
    """
    # Ensure age is 0 at the youngest individual
    age = amdata_load.var['age']
    amdata_load.var['age'] = age - age.min()
    age = amdata_load.var['age']

    # Youngest individuals define the priors for the starting mean and sd
    cutoff = np.quantile(age, early_quantile)
    early_mask = age < cutoff
    if not early_mask.any():
        # Many participants sharing the minimum age can make the quantile equal
        # that minimum, leaving nobody strictly below it. Include the tied
        # participants rather than averaging over an empty selection.
        early_mask = age <= cutoff

    early_X = amdata_load[:, early_mask].X
    amdata_load.obs['p_mean'] = early_X.mean(axis=1)
    amdata_load.obs['early_sd'] = early_X.std(axis=1)

    return amdata_load

##############################
###Model creation with PyMC###
##############################

#conditional model (fixed N and s)
def make_cond_beta_model(amdata_load, N, s):
    """
    Build the conditional Beta model, in which ``N`` (number of stem cells) and
    ``s`` (division rate per stem cell per year) are fixed constants.

    ``get_prior_params_and_prep`` runs on ``amdata_load`` itself. It shifts the
    ages so the youngest participant is 0 and writes ``obs['p_mean']`` and
    ``obs['early_sd']``. The model is then built from a copy.

    Per-site priors: p (initial level), eta (final level), the initial
    standard deviation of the methylated proportion, and log(omega). The
    observations ``X.T`` (participants x sites) get a Beta likelihood whose
    mean and variance come from ``mean_Z`` and ``var_Z_order_1``.

    Returns
    -------
    pm.Model
        The unsampled model.
    """
    prepped = get_prior_params_and_prep(amdata_load).copy()

    p_prior_mu = np.array(prepped.obs.p_mean)
    early_sd = np.array(prepped.obs.early_sd)
    observed = prepped.X.T
    # Column vector of ages so it broadcasts against per-site parameters
    ages = np.array(prepped.var.age)[:, None]

    model_coords = {'sites': prepped.obs.index.values,
                    'participants': prepped.var.index.values}

    with pm.Model(coords=model_coords) as sym_model:

        # Priors centred on the youngest participants' statistics
        P = pm.TruncatedNormal('p', mu=p_prior_mu, sigma=early_sd,
                               lower=0.001, upper=0.999, dims='sites')
        n = pm.Uniform('eta', lower=0.001, upper=0.999, dims='sites')

        # Initial standard deviation of the proportion of methylated stem cells
        std_prop_init = pm.TruncatedNormal('std_prop_init', mu=early_sd, sigma=0.02,
                                           lower=1e-15, upper=0.2, dims='sites')
        c = pm.Deterministic('var_sc_init', np.power(N*std_prop_init, 2), dims='sites')

        # NOTE(review): the upper bound is np.log10(1.99) but omega = exp(log_w)
        # uses the natural log, so omega is capped near exp(0.299) ~= 1.35
        # rather than 1.99 -- confirm this is intended.
        log_w = pm.TruncatedNormal('log(omega)', mu=-2, sigma=6,
                                   lower=-15, upper=np.log10(1.99), dims='sites')
        w = pm.Deterministic('omega', np.exp(log_w))

        # Delta-method mean/variance of Z, clipped into the valid Beta region
        mean, variance = ensure_valid_beta_mu_var(
            mean_Z(ages, s, n, w, P),
            var_Z_order_1(ages, s, N, n, w, P, c),
        )

        pm.Beta('m-values', mu=mean, sigma=np.sqrt(variance),
                dims=('participants', 'sites'), observed=observed)

    return sym_model

# Function to make the null model (fixed mean and variance over time)    
def make_null_model(amdata_load):
    """
    Build the null model, in which each site has a constant mean ``p`` and a
    constant variance ``std_prop_init**2``, neither depending on age.

    ``get_prior_params_and_prep`` runs on ``amdata_load`` itself. It shifts its
    ages and writes ``obs['p_mean']``/``obs['early_sd']``, even though ages are
    not used here. The model is built from a copy.

    Returns
    -------
    pm.Model
        The unsampled model with a Beta likelihood over ``X.T``
        (participants x sites).
    """
    prepped = get_prior_params_and_prep(amdata_load).copy()

    p_prior_mu = np.array(prepped.obs.p_mean)
    early_sd = np.array(prepped.obs.early_sd)
    observed = prepped.X.T

    model_coords = {'sites': prepped.obs.index.values,
                    'participants': prepped.var.index.values}

    with pm.Model(coords=model_coords) as sym_model:

        # Priors centred on the youngest participants' statistics
        P = pm.TruncatedNormal('p', mu=p_prior_mu, sigma=early_sd,
                               lower=0.001, upper=0.999, dims='sites')
        std_prop_init = pm.TruncatedNormal('std_prop_init', mu=early_sd, sigma=0.02,
                                           lower=1e-6, upper=0.1, dims='sites')
        raw_variance = pm.Deterministic('variance', np.power(std_prop_init, 2), dims='sites')

        # Clip into the region where a Beta with this mean/variance exists
        mean, variance = ensure_valid_beta_mu_var(P, raw_variance)

        pm.Beta('m-values', mu=mean, sigma=np.sqrt(variance),
                dims=('participants', 'sites'), observed=observed)

    return sym_model

# Function to make the full, unconditional model
def make_mcmc_order1(amdata_load):
    """
    Build the full (unconditional) model, in which N (number of stem cells)
    and N/s are inferred together with the per-site parameters.

    N and N/s get truncated-normal priors on the log10 scale (bounds 1..6),
    and s is derived as N / (N/s). Per-site priors are p, eta, the initial
    standard deviation of the methylated proportion, and log(omega). The
    observations ``X.T`` (participants x sites) get a Beta likelihood whose
    mean and variance come from ``mean_Z`` and ``var_Z_order_1``.

    ``get_prior_params_and_prep`` runs on ``amdata_load`` itself. It shifts its
    ages to start at 0 and writes ``obs['p_mean']``/``obs['early_sd']``. The
    model is built from a copy.

    Returns
    -------
    pm.Model
        The unsampled model.
    """
    prepped = get_prior_params_and_prep(amdata_load).copy()

    p_prior_mu = np.array(prepped.obs.p_mean)
    early_sd = np.array(prepped.obs.early_sd)
    observed = prepped.X.T
    # Column vector of ages so it broadcasts against per-site parameters
    ages = np.array(prepped.var.age)[:, None]

    model_coords = {'sites': prepped.obs.index.values,
                    'participants': prepped.var.index.values}

    with pm.Model(coords=model_coords) as sym_model:

        # Order of magnitude of N/s and N; s follows from the two
        log_Ns = pm.TruncatedNormal('log10(N/s)', mu=5, sigma=2, lower=1, upper=6)
        Ns = pm.Deterministic('N/s', np.power(10, log_Ns))
        log_N = pm.TruncatedNormal('log10(N)', mu=5, sigma=2, lower=1, upper=6)
        N = pm.Deterministic('N', np.power(10, log_N))
        s = pm.Deterministic('s', N/Ns)

        # Priors centred on the youngest participants' statistics
        P = pm.TruncatedNormal('p', mu=p_prior_mu, sigma=early_sd,
                               lower=0.0001, upper=1, dims='sites')
        n = pm.Uniform('eta', lower=0.0001, upper=0.999, dims='sites')

        std_prop_init = pm.TruncatedNormal('std_prop_init', mu=early_sd, sigma=0.02,
                                           lower=1e-6, upper=0.5, dims='sites')
        c = pm.Deterministic('var_sc_init', np.power(N*std_prop_init, 2), dims='sites')

        # NOTE(review): the upper bound is np.log10(1.99) but omega = exp(log_w)
        # uses the natural log, so omega is capped near exp(0.299) ~= 1.35
        # rather than 1.99 -- confirm this is intended.
        log_w = pm.TruncatedNormal('log(omega)', mu=-2, sigma=2,
                                   lower=-6, upper=np.log10(1.99), dims='sites')
        w = pm.Deterministic('omega', np.exp(log_w))

        # Delta-method mean/variance of Z, clipped into the valid Beta region
        mean, variance = ensure_valid_beta_mu_var(
            mean_Z(ages, s, n, w, P),
            var_Z_order_1(ages, s, N, n, w, P, c),
        )

        pm.Beta('m-values', mu=mean, sigma=np.sqrt(variance),
                dims=('participants', 'sites'), observed=observed)

    return sym_model

# Function to make a simple linear model
def make_simple_linear_model(adata, *, min_prior_sd=0.01):
    """
    Simple linear model: y = intercept + slope * age
    Uses Normal likelihood (standard linear regression)

    Ages are rescaled to [0, 1] before fitting, so ``slope`` is the change over
    the full observed age range. If every participant has the same age, the
    rescaling would divide by zero. In that case normalised ages are set to 0,
    and the slope is informed only by its prior.

    Parameters:
    -----------
    adata : AnnData object with:
        - adata.X.T : data matrix (participants x sites)
        - adata.var.age : age values for participants
        - adata.obs.index : site names
        - adata.var.index : participant names
    min_prior_sd : float, keyword-only
        Floor applied to the per-site standard deviations used as prior scales
        for ``intercept`` and ``sigma``. Without it, a site with no spread in
        the data would give those priors a scale of 0. Default 0.01.

    Returns:
    --------
    pm.Model
        The unsampled model. Diagnostic information about the priors is printed.
    """

    # Extract data
    data = adata.X.T  # participants x sites
    ages = np.array(adata.var.age)[:, None]  # participants x 1

    # Normalize ages to [0, 1] for numerical stability
    age_min, age_max = ages.min(), ages.max()
    age_span = age_max - age_min
    if age_span == 0:
        # A single distinct age cannot be rescaled (0/0 -> NaN).
        ages_norm = np.zeros_like(ages, dtype=float)
    else:
        ages_norm = (ages - age_min) / age_span

    n_participants, n_sites = data.shape

    # Data-informed prior parameters
    data_mean = np.mean(data, axis=0)  # mean per site
    data_std = np.std(data, axis=0)    # standard deviation per site

    # Print diagnostic information about priors
    print("DATA-INFORMED PRIORS:")
    print("=" * 40)
    print(f"Sites: {n_sites}, Participants: {n_participants}")
    print(f"Age range: {age_min:.1f} to {age_max:.1f} years")
    print(f"Methylation range: {data.min():.3f} to {data.max():.3f}")
    print(f"Intercept prior mean: {data_mean.mean():.3f}")
    print(f"Data standard deviation: {data_std.mean():.3f}")

    # Calculate slope prior scale
    data_range = np.max(data, axis=0) - np.min(data, axis=0)
    slope_scale = np.median(data_range) / 2
    slope_scale = max(slope_scale, 0.1)
    print(f"Slope prior scale: {slope_scale:.3f}")
    print()

    # Prior scales must be positive; constant sites would otherwise give 0.
    prior_sd = np.maximum(data_std, min_prior_sd)

    # Set up coordinates
    coords = {
        'sites': adata.obs.index.values,
        'participants': adata.var.index.values
    }

    with pm.Model(coords=coords) as model:

        # === PRIORS ===

        # Intercept: Normal prior centered on observed site means
        intercept = pm.Normal('intercept',
                              mu=data_mean,      # Site-specific mean
                              sigma=prior_sd,    # Site-specific std (floored)
                              dims='sites')

        # Slope: change in methylation over normalized age range [0,1]
        slope = pm.Normal('slope',
                          mu=0,              # No prior assumption about direction
                          sigma=slope_scale, # Data-informed scale
                          dims='sites')

        # Standard deviation for Normal distribution
        sigma = pm.HalfNormal('sigma',
                              sigma=prior_sd,  # Data-informed scale (floored)
                              dims='sites')

        # === LINEAR MODEL ===

        # Mean methylation: intercept + slope * normalized_age
        mu = intercept + slope * ages_norm

        # === LIKELIHOOD ===

        # Observed data with Normal distribution
        pm.Normal('m-values',
                  mu=mu,
                  sigma=sigma,
                  dims=('participants', 'sites'),
                  observed=data)

    return model


# Function to define a version of our model that includes all mammals at once (the joint model)
def make_mcmc_beta_model_stacked(data_matrix, t_matrix, data_list, min_samples, n_sites, model_type):
    """
    Build the joint model that fits every species (one entry of ``data_list``
    each) at once.

    Parameters
    ----------
    data_matrix : array-like
        Observed values for the likelihood, which carries the dims
        ("samples", "sites", "species"). Presumably the shape is
        (min_samples, n_sites, len(data_list)); TODO confirm against the caller.
    t_matrix : array-like
        Ages passed to ``mean_Z`` / ``var_Z_order_1``. Must broadcast against
        the (sites, species) parameters.
    data_list : list
        One AnnData-like object per species. ``uns['common_name']`` labels the
        species coordinate and ``uns['lifespan']`` feeds the lifespan data.
        Participants below the 10% age quantile supply the priors for p and the
        initial standard deviation. If ties leave nobody strictly below the
        quantile, those at the quantile are used.
    min_samples, n_sites : int
        Lengths of the "samples" and "sites" coordinates.
    model_type : str
        How log10(N/s) is parameterised:
        'free_ns' - independent prior per species;
        'scaled_ns' - fixed function of log10(lifespan);
        'single_ns' - one value shared by all species;
        'hierarchical_scaled_ns' - linear in log10(lifespan) with estimated
        slope and intercept.

    Returns
    -------
    pm.Model
        The unsampled model.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the values above.
    """
    # Validate before any work so a typo fails fast with a clear message.
    valid_model_types = ('free_ns', 'scaled_ns', 'single_ns', 'hierarchical_scaled_ns')
    if model_type not in valid_model_types:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {valid_model_types}")

    # Establish coordinate system
    coords = {
        'samples': [f'sample_{i}' for i in range(min_samples)],
        'sites': [f'site_{i}' for i in range(n_sites)],
        'species': [species_data.uns['common_name'] for species_data in data_list]
    }

    def early_subset(species_data, quantile=0.1):
        # Youngest participants of one species (age below the quantile). Fall
        # back to "<=" when ties at the minimum age make the strict selection
        # empty, which would otherwise give NaN prior means/sds.
        ages = species_data.var.age
        cutoff = np.quantile(ages, quantile)
        mask = ages < cutoff
        if not mask.any():
            mask = ages <= cutoff
        return species_data[:, mask]

    early = [early_subset(species_data) for species_data in data_list]

    # Starting mean and sd per site, stacked to (sites, species)
    p_mean = np.stack([e.X.mean(axis=1) for e in early], axis=-1)
    p_sd = np.stack([e.X.std(axis=1) for e in early], axis=-1)

    # If any of the standard deviations are 0, set them to 0.01 (they are used
    # as prior scales/means below)
    p_sd[p_sd <= 0] = 0.01

    ### Model definition ###
    with pm.Model(coords=coords) as sym_model:

        # Set observed data
        observed = pm.MutableData('data', data_matrix)
        t = pm.MutableData('age', t_matrix)
        lifespan = np.array([species_data.uns['lifespan'] for species_data in data_list])
        lifespan = pm.Data('lifespan', lifespan, dims='species')

        ### set Ns priors depending on the model type (validated above) ###
        if model_type == 'free_ns':
            # N/s can be anything for each species
            log_Ns = pm.TruncatedNormal('log10(N_on_s)', mu=5, sigma=2, lower=1, upper=6, dims='species')
        elif model_type == 'scaled_ns':
            # N/s scales with lifespan
            log_Ns = pm.Deterministic('log10(N_on_s)', 1.6016*np.log10(lifespan)+0.9866, dims='species')
        elif model_type == 'single_ns':
            # N/s is the same for all species
            log_Ns = pm.TruncatedNormal('log10(N_on_s)', mu=5, sigma=2, lower=1, upper=6)
        else:  # 'hierarchical_scaled_ns'
            # Estimate the scaling relationship parameters (from separate model)
            slope = pm.Normal('scaling_slope', mu=1.42, sigma=0.2)
            intercept = pm.Normal('scaling_intercept', mu=0.97, sigma=0.3)

            # Apply the scaling relationship with estimated parameters
            log_Ns = pm.Deterministic('log10(N_on_s)',
                                      slope * np.log10(lifespan) + intercept,
                                      dims='species')

        # Set other priors
        Ns = pm.Deterministic('N_on_s', np.power(10, log_Ns))
        log_N = pm.TruncatedNormal('log10(N)', mu=5, sigma=2, lower=1, upper=6, dims='species')
        N = pm.Deterministic('N', np.power(10, log_N))
        s = pm.Deterministic('s', N/Ns)

        # Use prior knowledge to define priors on P, eta and c
        P = pm.TruncatedNormal('p', mu=p_mean,
                               sigma=p_sd,
                               lower=0.001, upper=0.999, dims=['sites', 'species'])

        n = pm.Uniform('eta', lower=0.0001, upper=0.999, dims=['sites', 'species'])

        # Initial standard deviation of the proportion of methylated stem cells
        std_prop_init = pm.TruncatedNormal("std_prop_init",
                                           mu=p_sd,
                                           sigma=0.02,
                                           lower=1e-6, upper=0.5,
                                           dims=['sites', 'species'])

        c = pm.Deterministic('var_sc_init',
                             np.power(N*std_prop_init, 2),
                             dims=['sites', 'species'])

        # Define general priors for omega using log scale.
        # NOTE(review): the upper bound is np.log10(1.99) but omega = exp(log_w)
        # uses the natural log (cap ~1.35, not 1.99) -- confirm this is intended.
        log_w = pm.TruncatedNormal('log(omega)',
                                   mu=-2,
                                   sigma=2,
                                   lower=-6, upper=np.log10(1.99), dims=['sites', 'species'])

        w = pm.Deterministic('omega', np.exp(log_w))

        # Delta-method mean/variance of Z, clipped into the valid Beta region
        mean, variance = ensure_valid_beta_mu_var(
            mean_Z(t, s, n, w, P),
            var_Z_order_1(t, s, N, n, w, P, c),
        )

        pm.Beta("m-values", mu=mean, sigma=np.sqrt(variance),
                dims=("samples", "sites", "species"), observed=observed)

    return sym_model