###################
### Description ###
###################
# This script is for general imports that are used in multiple scripts.
# Includes general helper functions, model mean and variance definitions,
# and model creation functions with PyMC.
##############
### Author ###
##############
# Sam Crofts (sam.crofts@ed.ac.uk)
########################################################################################################################
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from multiprocessing import Pool
from multiprocessing import cpu_count
import anndata as ad
import json
import pickle
import math
import random
import arviz as az
from scipy.stats import bootstrap
import pymc as pm
from sklearn.feature_selection import r_regression
from scipy.stats import bootstrap
from itertools import product
from functools import partial
from scipy.stats import norm
from scipy.stats import beta
from scipy.stats import ks_2samp
from scipy.stats import spearmanr
import contextlib
import os
##########################
#General helper functions#
##########################
# Function to uniformly sample participants from the data with regard to age
# (as much as possible given the underlying distribution of ages in the data)
def sample_to_uniform_age(amdata, n_part):
    """
    Draw participants so that their ages are spread as uniformly as the data allow.

    For each draw a target age is sampled uniformly between the smallest and
    largest value of ``amdata.var.age``; the not-yet-selected participant whose
    age is closest to that target is taken (sampling without replacement).

    Parameters:
    -----------
    amdata : AnnData-like object
        Participants are indexed along the second axis; ``amdata.var`` must
        provide an ``age`` column.
    n_part : int
        Number of participants to draw.

    Returns:
    --------
    ``amdata[:, selected]`` with the participants in the order they were drawn.
    """
    ages = amdata.var.age
    youngest, oldest = ages.min(), ages.max()
    pool = list(amdata.var.index)
    selected = []
    for _ in range(n_part):
        target_age = np.random.uniform(youngest, oldest)
        # Position, within the remaining pool, of the closest participant
        closest = np.abs(ages.loc[pool] - target_age).argmin()
        selected.append(pool.pop(closest))
    return amdata[:, selected]
# Function to individually match samples from one dataset to another based on age
def match_ages(amdata_to_sample_from, amdata_to_match_on):
    """
    Pick, for every participant of one dataset, the closest-aged participant
    of another dataset (each candidate is used at most once).

    Parameters:
    -----------
    amdata_to_sample_from : AnnData-like object
        Pool of candidates; ``.var`` must provide an ``age`` column.
    amdata_to_match_on : AnnData-like object
        Dataset whose ages are matched, processed in column order.

    Returns:
    --------
    ``amdata_to_sample_from[:, matched]`` where ``matched`` lists the chosen
    candidates in the order of the participants they were matched to.
    """
    candidate_ages = amdata_to_sample_from.var.age
    target_ages = amdata_to_match_on.var.age
    pool = list(amdata_to_sample_from.var.index)
    matched = []
    for i in range(amdata_to_match_on.shape[1]):
        # Absolute age gap between every remaining candidate and the target
        gaps = np.abs(candidate_ages.loc[pool] - target_ages.iloc[i])
        closest = gaps.argmin()
        matched.append(pool.pop(closest))
    return amdata_to_sample_from[:, matched]
# Function to individually match samples from one dataset to another based on both age and sex
def match_ages_and_sex(amdata_to_sample_from, amdata_to_match_on):
    """
    Pick, for every participant of one dataset, the closest-aged participant
    of the same sex from another dataset (each candidate is used at most once).

    Parameters:
    -----------
    amdata_to_sample_from : AnnData-like object
        Pool of candidates; ``.var`` must provide ``age`` and ``sex`` columns.
    amdata_to_match_on : AnnData-like object
        Dataset whose participants are matched, processed in column order;
        ``.var`` must provide ``age`` and ``sex`` columns.

    Returns:
    --------
    ``amdata_to_sample_from[:, used_idx]`` where ``used_idx`` lists the chosen
    candidates in the order of the participants they were matched to.

    Raises:
    -------
    ValueError
        If no unused candidate of the required sex is left.
    """
    source_var = amdata_to_sample_from.var
    target_var = amdata_to_match_on.var
    n_part = amdata_to_match_on.shape[1]
    # A set gives O(1) membership tests and lets us remove the chosen
    # participant by label instead of by a position in a different list
    remaining = set(source_var.index)
    used_idx = []
    for i in range(n_part):
        sampled_age = target_var.age.iloc[i]
        sex = target_var.sex.iloc[i]
        same_sex_idx = source_var.index[(source_var.sex == sex).to_numpy()]
        # Just keep those that have not been used yet (original order kept,
        # so ties are resolved in favour of the first candidate)
        candidates = [x for x in same_sex_idx if x in remaining]
        if not candidates:
            raise ValueError(
                f"No unused participant with sex {sex!r} left to match "
                f"participant number {i}"
            )
        near_idx = np.abs(source_var.age.loc[candidates] - sampled_age).argmin()
        chosen = candidates[near_idx]
        used_idx.append(chosen)
        # Remove the participant that was actually chosen
        remaining.discard(chosen)
    return amdata_to_sample_from[:, used_idx]
# Function to ensure that the mean and variance are valid for a Beta distribution
def ensure_valid_beta_mu_var(mu, var):
    """
    Force a mean/variance pair into the region where a Beta distribution exists.

    A Beta distribution requires 0 < mu < 1 and 0 < var < mu * (1 - mu).
    Both inputs are clipped with ``pm.math.clip`` so these constraints hold.

    Parameters:
    -----------
    mu : array-like
        Mean values to be checked.
    var : array-like
        Variance values to be checked.

    Returns:
    --------
    mu_clipped : array-like
        Mean values clipped to [0.001, 0.999].
    var_clipped : array-like
        Variance values clipped to [1e-10, 0.99 * mu_clipped * (1 - mu_clipped)].
    """
    lowest_mu, highest_mu = 0.001, 0.999
    safe_mu = pm.math.clip(mu, lowest_mu, highest_mu)

    # Stay at 99% of the theoretical maximum so the bound is never hit exactly
    var_ceiling = safe_mu * (1 - safe_mu) * 0.99
    safe_var = pm.math.clip(var, 1e-10, var_ceiling)

    return safe_mu, safe_var
#####################################
#Model mean and variance definitions#
#####################################
###Model of methylation changes via stem cell divisions. See mathematical details in supplementary###
##Parameters##
# N = number of stem cells
# M = number of methylated stem cells
# t = time (years)
# s = division rate per stem cell per year
# n = eta, theoretical final methylation level at t=infinity, equal to (u)/(m+u)
# m = P_(M->U), probability of a methylated site becoming unmethylated per division
# u = P_(U->M), probability of an unmethylated site becoming methylated per division
# w = omega, equal to m+u
# Z = total number of methylated cells relative to N (i.e. Z(t) = M(t)/N(t))
# p = theoretical initial methylation level at t=0
# c = initial variance of methylated stem cells, M, at t=0
def mean_Z(t, s, n, w, p):
    """
    Mean of Z at time t: relaxes from the initial level p towards eta (n)
    with the exponential factor exp(-2*s*t*w).
    """
    decay = np.exp(-2*s*t*w)
    return n + decay*(p-n)
def cov_NM(t, s, N, n, w, p):
    """
    Covariance between the number of stem cells (N) and the number of
    methylated stem cells (M) at time t.
    """
    decay = np.exp(-2*s*t*w)
    bracket = n + decay*(1-w)*(p-n)
    return 2*s*N*t*bracket
def var_N(t, s, N):
    """Variance of the number of stem cells at time t: 2*s*N*t."""
    variance = 2*s*N*t
    return variance
def var_M(t, s, N, n, w, p, c):
    """
    Variance of the number of methylated stem cells (M) at time t.

    The result has the form
        (A0 + A2*exp(-2*s*t*w) + A4*exp(-4*s*t*w)) / (2*w)
    where A0, A2 and A4 are the three coefficients computed below.
    See the parameter list above for the meaning of each symbol.
    """
    w_sq = np.power(w, 2)
    n_sq = np.power(n, 2)
    growth = 1+2*s*t

    # Coefficient of the term without an exponential factor
    coef_0 = N*n*(1 + w_sq - n*(1 + w_sq - 4*s*w*t))

    # Coefficient of exp(-2*s*t*w)
    coef_2 = 2*N*(n-p)*(-1 + w - w_sq + 2*n*(1 - w*growth + w_sq*growth))

    # Coefficient of exp(-4*s*t*w); the only place the initial variance c enters
    coef_4 = (n_sq*N*(-3 + 4*w - 3*w_sq)
              + 2*(c*w + N*p*(-1 + w - w_sq))
              + n*N*(1+4*p - 2*w*(1+2*p) + w_sq*(1+4*p)))

    slow_decay = np.exp(-2*s*t*w)
    fast_decay = np.exp(-4*s*t*w)
    return (coef_0 + coef_2*slow_decay + coef_4*fast_decay)/(2*w)
def var_Z_order_1(t, s, N, n, w, p, c):
    """
    First-order (delta method) approximation of the variance of Z = M/N at
    time t, built from the mean of Z, the variances of N and M and their
    covariance.
    """
    mean_frac = mean_Z(t, s, n, w, p)
    numerator = (np.power(mean_frac, 2)*var_N(t, s, N)
                 + var_M(t, s, N, n, w, p, c)
                 - 2*mean_frac*cov_NM(t, s, N, n, w, p))
    return numerator/(np.power(N, 2))
# Function to get prior parameters from data and prep data for modelling
def get_prior_params_and_prep(amdata_load):
    """
    Shift ages to start at 0 and store data-driven prior parameters per site.

    The object is modified in place: ``var.age`` is shifted so the youngest
    individual has age 0, and two columns are written to ``obs``:
    ``p_mean`` (mean over the youngest individuals) and ``early_sd``
    (standard deviation over the youngest individuals).

    The youngest individuals are those with an age strictly below the 5%
    age quantile. If nobody is strictly below it (which happens when at
    least 5% of individuals share the youngest age), individuals with an age
    equal to the quantile are used instead, so the priors are never computed
    from an empty selection.

    Parameters:
    -----------
    amdata_load : AnnData-like object
        ``X`` is indexed as (sites, participants); ``var`` must provide an
        ``age`` column.

    Returns:
    --------
    The same (modified) ``amdata_load`` object.
    """
    # Ensure age is 0 at the youngest individual
    amdata_load.var.age = amdata_load.var.age - amdata_load.var.age.min()
    ages = amdata_load.var.age

    # Get the youngest individuals to make the prior for the starting mean and sd
    cutoff = np.quantile(ages, 0.05)
    early_mask = (ages < cutoff).to_numpy()
    if not early_mask.any():
        # The quantile coincides with the youngest age: a strict comparison
        # would select nobody and yield NaN priors
        early_mask = (ages <= cutoff).to_numpy()
    early = amdata_load[:, early_mask]

    # Starting mean and sd
    p_mean = early.X.mean(axis=1)
    early_sd = early.X.std(axis=1)
    amdata_load.obs['p_mean'] = p_mean
    amdata_load.obs['early_sd'] = early_sd
    return amdata_load
##############################
###Model creation with PyMC###
##############################
#conditional model (fixed N and s)
def make_cond_beta_model(amdata_load, N, s):
    """
    Build the conditional Beta model, in which N and s are supplied as fixed
    values rather than being given priors.

    Parameters:
    -----------
    amdata_load : AnnData-like object
        Passed through ``get_prior_params_and_prep`` (which modifies it in
        place) before a copy is taken for building the model.
    N : number of stem cells (fixed).
    s : division rate per stem cell per year (fixed).

    Returns:
    --------
    The PyMC model (not sampled).
    """
    amdata = get_prior_params_and_prep(amdata_load).copy()
    prior_p_mean = np.array(amdata.obs.p_mean)
    # The early standard deviation serves both as the prior width of p and
    # as the prior mean of the initial standard deviation
    prior_early_sd = np.array(amdata.obs.early_sd)

    # Establish coordinate system
    coords = {
        'sites': amdata.obs.index.values,
        'participants': amdata.var.index.values,
    }

    ### Model definition ###
    with pm.Model(coords=coords) as sym_model:
        data = amdata.X.T
        t = np.array(amdata.var.age)[:, None]

        ### Define priors ###
        # Use prior knowledge to define priors on P, eta and c
        P = pm.TruncatedNormal(
            'p',
            mu=prior_p_mean,
            sigma=prior_early_sd,
            lower=0.001,
            upper=0.999,
            dims='sites',
        )
        n = pm.Uniform('eta', lower=0.001, upper=0.999, dims='sites')

        # Initial standard deviation of the proportion of methylated stem cells
        std_prop_init = pm.TruncatedNormal(
            "std_prop_init",
            mu=prior_early_sd,
            sigma=0.02,
            lower=1e-15,
            upper=0.2,
            dims='sites',
        )
        c = pm.Deterministic(
            'var_sc_init',
            np.power(N*std_prop_init, 2),
            dims='sites',
        )

        # Define general priors for omega using log scale
        # NOTE(review): the upper bound is np.log10(1.99) while omega is
        # recovered with np.exp, so omega is capped at exp(log10(1.99)) ~ 1.35
        # rather than 1.99 -- confirm which base is intended.
        log_w = pm.TruncatedNormal(
            'log(omega)',
            mu=-2,
            sigma=6,
            lower=-15,
            upper=np.log10(1.99),
            dims='sites',
        )
        w = pm.Deterministic('omega', np.exp(log_w))

        # Evolution of mean and variance of Z using delta method
        mean = mean_Z(t, s, n, w, P)
        variance = var_Z_order_1(t, s, N, n, w, P, c)

        # Make sure mean and variance are valid for Beta distribution
        mean, variance = ensure_valid_beta_mu_var(mean, variance)

        # Define likelihood
        pm.Beta(
            "m-values",
            mu=mean,
            sigma=np.sqrt(variance),
            dims=("participants", "sites"),
            observed=data,
        )
    return sym_model
# Function to make the null model (fixed mean and variance over time)
def make_null_model(amdata_load):
    """
    Build the null model: every site has a mean and a variance that do not
    change with age.

    Parameters:
    -----------
    amdata_load : AnnData-like object
        Passed through ``get_prior_params_and_prep`` (which modifies it in
        place) before a copy is taken for building the model.

    Returns:
    --------
    The PyMC model (not sampled).
    """
    # Get appropriate priors from data and set ages to start at 0
    amdata = get_prior_params_and_prep(amdata_load).copy()
    prior_p_mean = np.array(amdata.obs.p_mean)
    # Used both as the prior width of p and as the prior mean of the
    # standard deviation
    prior_early_sd = np.array(amdata.obs.early_sd)

    # Establish coordinate system
    coords = {
        'sites': amdata.obs.index.values,
        'participants': amdata.var.index.values,
    }

    ### Model definition ###
    with pm.Model(coords=coords) as sym_model:
        data = amdata.X.T

        ### Define priors ###
        P = pm.TruncatedNormal(
            'p',
            mu=prior_p_mean,
            sigma=prior_early_sd,
            lower=0.001,
            upper=0.999,
            dims='sites',
        )
        std_prop_init = pm.TruncatedNormal(
            "std_prop_init",
            mu=prior_early_sd,
            sigma=0.02,
            lower=1e-6,
            upper=0.1,
            dims='sites',
        )
        variance = pm.Deterministic(
            'variance',
            np.power(std_prop_init, 2),
            dims='sites',
        )

        # Make sure mean and variance are valid for Beta distribution
        mean, variance = ensure_valid_beta_mu_var(P, variance)

        # Define likelihood
        pm.Beta(
            "m-values",
            mu=mean,
            sigma=np.sqrt(variance),
            dims=("participants", "sites"),
            observed=data,
        )
    return sym_model
# Function to make the full, unconditional model
def make_mcmc_order1(amdata_load):
    """
    Build the full, unconditional model: N and N/s receive priors on a log10
    scale and s is derived from them as N / (N/s).

    Parameters:
    -----------
    amdata_load : AnnData-like object
        Passed through ``get_prior_params_and_prep`` (which modifies it in
        place) before a copy is taken for building the model.

    Returns:
    --------
    The PyMC model (not sampled).
    """
    # Get appropriate priors from data and set ages to start at 0
    amdata = get_prior_params_and_prep(amdata_load).copy()
    prior_p_mean = np.array(amdata.obs.p_mean)
    # Used both as the prior width of p and as the prior mean of the initial
    # standard deviation
    prior_early_sd = np.array(amdata.obs.early_sd)

    # Establish coordinate system
    coords = {
        'sites': amdata.obs.index.values,
        'participants': amdata.var.index.values,
    }

    ### Model definition ###
    with pm.Model(coords=coords) as sym_model:
        data = amdata.X.T
        t = np.array(amdata.var.age)[:, None]

        ### Define priors ###
        # Define the order of magnitude of N/s and N
        log_Ns = pm.TruncatedNormal('log10(N/s)', mu=5, sigma=2, lower=1, upper=6)
        Ns = pm.Deterministic('N/s', np.power(10, log_Ns))
        log_N = pm.TruncatedNormal('log10(N)', mu=5, sigma=2, lower=1, upper=6)
        N = pm.Deterministic('N', np.power(10, log_N))
        s = pm.Deterministic('s', N/Ns)

        # Use prior knowledge to define priors on P, eta and c
        P = pm.TruncatedNormal(
            'p',
            mu=prior_p_mean,
            sigma=prior_early_sd,
            lower=0.0001,
            upper=1,
            dims='sites',
        )
        n = pm.Uniform('eta', lower=0.0001, upper=0.999, dims='sites')
        std_prop_init = pm.TruncatedNormal(
            "std_prop_init",
            mu=prior_early_sd,
            sigma=0.02,
            lower=1e-6,
            upper=0.5,
            dims='sites',
        )
        c = pm.Deterministic(
            'var_sc_init',
            np.power(N*std_prop_init, 2),
            dims='sites',
        )

        # Define general priors for omega using log scale
        # NOTE(review): the upper bound is np.log10(1.99) while omega is
        # recovered with np.exp, so omega is capped at exp(log10(1.99)) ~ 1.35
        # rather than 1.99 -- confirm which base is intended.
        log_w = pm.TruncatedNormal(
            'log(omega)',
            mu=-2,
            sigma=2,
            lower=-6,
            upper=np.log10(1.99),
            dims='sites',
        )
        w = pm.Deterministic('omega', np.exp(log_w))

        # Evolution of mean and variance of Z using delta method
        mean = mean_Z(t, s, n, w, P)
        variance = var_Z_order_1(t, s, N, n, w, P, c)

        # Ensure mean and variance are valid for Beta distribution
        mean, variance = ensure_valid_beta_mu_var(mean, variance)

        # Define likelihood
        pm.Beta(
            "m-values",
            mu=mean,
            sigma=np.sqrt(variance),
            dims=("participants", "sites"),
            observed=data,
        )
    return sym_model
# Function to make a simple linear model
def make_simple_linear_model(adata):
    """
    Simple linear model: y = intercept + slope * age
    Uses Normal likelihood (standard linear regression)

    Ages are rescaled to [0, 1] before entering the model, so the slope is the
    change over the full observed age range. Diagnostic information about the
    data-informed priors is printed.

    Parameters:
    -----------
    adata : AnnData object with:
        - adata.X.T : data matrix (participants x sites)
        - adata.var.age : age values for participants
        - adata.obs.index : site names
        - adata.var.index : participant names

    Returns:
    --------
    The PyMC model (not sampled).

    Raises:
    -------
    ValueError
        If all participants have the same age, in which case the ages cannot
        be normalised (division by zero) and no slope can be estimated.
    """
    # Extract data
    data = adata.X.T  # participants x sites
    ages = np.array(adata.var.age)[:, None]  # participants x 1

    # Normalize ages to [0, 1] for numerical stability
    age_min, age_max = ages.min(), ages.max()
    if age_max == age_min:
        # Without this guard the normalisation is 0/0 and the model is built
        # on NaN regressors, failing only later and obscurely at sampling
        raise ValueError(
            "All participants have the same age; ages cannot be normalised "
            "and a slope against age cannot be estimated."
        )
    ages_norm = (ages - age_min) / (age_max - age_min)
    n_participants, n_sites = data.shape

    # Data-informed prior parameters
    data_mean = np.mean(data, axis=0)  # mean per site
    data_std = np.std(data, axis=0)  # standard deviation per site

    # Print diagnostic information about priors
    print("DATA-INFORMED PRIORS:")
    print("=" * 40)
    print(f"Sites: {n_sites}, Participants: {n_participants}")
    print(f"Age range: {age_min:.1f} to {age_max:.1f} years")
    print(f"Methylation range: {data.min():.3f} to {data.max():.3f}")
    print(f"Intercept prior mean: {data_mean.mean():.3f}")
    print(f"Data standard deviation: {data_std.mean():.3f}")

    # Calculate slope prior scale: half the median per-site range, but never
    # narrower than 0.1
    data_range = np.max(data, axis=0) - np.min(data, axis=0)
    slope_scale = np.median(data_range) / 2
    slope_scale = max(slope_scale, 0.1)
    print(f"Slope prior scale: {slope_scale:.3f}")
    print()

    # Set up coordinates
    coords = {
        'sites': adata.obs.index.values,
        'participants': adata.var.index.values
    }

    with pm.Model(coords=coords) as model:
        # === PRIORS ===
        # Intercept: Normal prior centered on observed site means
        intercept = pm.Normal('intercept',
                              mu=data_mean,  # Site-specific mean
                              sigma=data_std,  # Site-specific std
                              dims='sites')
        # Slope: change in methylation over normalized age range [0,1]
        slope = pm.Normal('slope',
                          mu=0,  # No prior assumption about direction
                          sigma=slope_scale,  # Data-informed scale
                          dims='sites')
        # Standard deviation for Normal distribution
        sigma = pm.HalfNormal('sigma',
                              sigma=data_std,  # Data-informed scale
                              dims='sites')

        # === LINEAR MODEL ===
        # Mean methylation: intercept + slope * normalized_age
        mu = intercept + slope * ages_norm

        # === LIKELIHOOD ===
        # Observed data with Normal distribution
        pm.Normal('m-values',
                  mu=mu,
                  sigma=sigma,
                  dims=('participants', 'sites'),
                  observed=data)
    return model
# Function to define a version of our model that includes all mammals at once (the joint model)
def make_mcmc_beta_model_stacked(data_matrix, t_matrix, data_list, min_samples, n_sites, model_type):
    """
    Build the joint model that covers all species at once.

    Parameters:
    -----------
    data_matrix : array-like
        Observed values handed to the Beta likelihood, which is declared with
        dims ("samples", "sites", "species").
    t_matrix : array-like
        Ages registered in the model as 'age'
        (presumably broadcastable against data_matrix -- confirm with caller).
    data_list : list of AnnData-like objects, one per species
        Each must provide ``uns['common_name']``, ``uns['lifespan']``,
        ``var.age`` and ``X``. The individuals below the 10% age quantile of
        each species supply the prior mean and standard deviation of p.
    min_samples : int
        Length of the 'samples' coordinate.
    n_sites : int
        Length of the 'sites' coordinate.
    model_type : str
        How log10(N/s) is modelled:
        - 'free_ns' : independent prior for each species
        - 'scaled_ns' : fixed linear function of log10(lifespan)
        - 'single_ns' : one value shared by all species
        - 'hierarchical_scaled_ns' : linear function of log10(lifespan) with
          slope and intercept estimated

    Returns:
    --------
    The PyMC model (not sampled).

    Raises:
    -------
    ValueError
        If ``model_type`` is not one of the options listed above.
    """
    valid_model_types = ('free_ns', 'scaled_ns', 'single_ns', 'hierarchical_scaled_ns')
    if model_type not in valid_model_types:
        # Fail early: otherwise log10(N_on_s) would never be defined and the
        # error would only surface inside the model context
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {valid_model_types}"
        )

    # Establish coordinate system
    coords = {
        'samples': [f'sample_{i}' for i in range(min_samples)],
        'sites': [f'site_{i}' for i in range(n_sites)],
        'species': [data.uns['common_name'] for data in data_list]
    }

    # Starting mean and sd, from the youngest individuals of each species
    early_values = [data[:, data.var.age < np.quantile(data.var.age, 0.1)].X
                    for data in data_list]
    p_mean = np.stack([values.mean(axis=1) for values in early_values], axis=-1)
    p_sd = np.stack([values.std(axis=1) for values in early_values], axis=-1)
    # If any of the standard deviations are 0, set them to 0.01
    p_sd[p_sd <= 0] = 0.01

    ### Model definition ###
    with pm.Model(coords=coords) as sym_model:
        # Set observed data
        data = pm.MutableData('data', data_matrix)
        t = pm.MutableData('age', t_matrix)
        lifespan = np.array([data_l.uns['lifespan'] for data_l in data_list])
        lifespan = pm.Data('lifespan', lifespan, dims='species')

        ### Set Ns priors depending on the model type ###
        if model_type == 'free_ns':
            # N/s can be anything for each species
            log_Ns = pm.TruncatedNormal('log10(N_on_s)', mu=5, sigma=2,
                                        lower=1, upper=6, dims='species')
        elif model_type == 'scaled_ns':
            # N/s scales with lifespan
            log_Ns = pm.Deterministic('log10(N_on_s)',
                                      1.6016*np.log10(lifespan)+0.9866,
                                      dims='species')
        elif model_type == 'single_ns':
            # N/s is the same for all species
            log_Ns = pm.TruncatedNormal('log10(N_on_s)', mu=5, sigma=2,
                                        lower=1, upper=6)
        else:  # 'hierarchical_scaled_ns'
            # Estimate the scaling relationship parameters (from separate model)
            slope = pm.Normal('scaling_slope', mu=1.42, sigma=0.2)
            intercept = pm.Normal('scaling_intercept', mu=0.97, sigma=0.3)
            # Apply the scaling relationship with estimated parameters
            log_Ns = pm.Deterministic('log10(N_on_s)',
                                      slope * np.log10(lifespan) + intercept,
                                      dims='species')

        # Set other priors
        Ns = pm.Deterministic('N_on_s', np.power(10, log_Ns))
        log_N = pm.TruncatedNormal('log10(N)', mu=5, sigma=2, lower=1, upper=6,
                                   dims='species')
        N = pm.Deterministic('N', np.power(10, log_N))
        s = pm.Deterministic('s', N/Ns)

        # Use prior knowledge to define priors on P, eta and c
        P = pm.TruncatedNormal('p', mu=p_mean,
                               sigma=p_sd,
                               lower=0.001, upper=0.999,
                               dims=['sites', 'species'])
        n = pm.Uniform('eta', lower=0.0001, upper=0.999, dims=['sites', 'species'])

        # Initial standard deviation of the proportion of methylated stem cells
        std_prop_init = pm.TruncatedNormal("std_prop_init",
                                           mu=p_sd,
                                           sigma=0.02,
                                           lower=1e-6, upper=0.5,
                                           dims=['sites', 'species'])
        c = pm.Deterministic('var_sc_init',
                             np.power(N*std_prop_init, 2),
                             dims=['sites', 'species'])

        # Define general priors for omega using log scale
        # NOTE(review): the upper bound is np.log10(1.99) while omega is
        # recovered with np.exp, so omega is capped at exp(log10(1.99)) ~ 1.35
        # rather than 1.99 -- confirm which base is intended.
        log_w = pm.TruncatedNormal('log(omega)',
                                   mu=-2,
                                   sigma=2,
                                   lower=-6, upper=np.log10(1.99),
                                   dims=['sites', 'species'])
        w = pm.Deterministic('omega', np.exp(log_w))

        # Evolution of mean and variance of Z using delta method
        mean = mean_Z(t, s, n, w, P)
        variance = var_Z_order_1(t, s, N, n, w, P, c)

        # Ensure mean and variance are valid for Beta distribution
        mean, variance = ensure_valid_beta_mu_var(mean, variance)

        pm.Beta("m-values", mu=mean, sigma=np.sqrt(variance),
                dims=("samples", "sites", "species"), observed=data)
    return sym_model