# Pertussis-seroprevalence / m-analyze_SCMs.R
#######################################################################################################
# Cluster contact matrices
#######################################################################################################

# Start from an empty workspace: removes every object listed by ls().
# NOTE(review): when this script is source()d at top level this also wipes the
# caller's objects — confirm that is intended.
rm(list = ls())
# Project scripts. NOTE(review): presumably s-base_packages.R attaches the
# tidyverse-style packages used below (read_csv, map_chr, %>%, ggplot, ...) and
# each f-*.R file defines the function named in its file name — TODO confirm.
source("s-base_packages.R")
source("f-Project_M.R") # Project_M() is called when loading the contact matrices
source("f-Upsize_M.R") # not called anywhere in this script
source("f-PlotMatrix.R") # PlotMatrix() is called after loading the contact matrices
source("f-compute_R0.R") # compute_R0() is called to build the next-generation matrices
library(cluster) # agnes()
library(factoextra) # get_dist(), fviz_dist(), fviz_nbclust(), fviz_dend(), hcut()
library(NbClust) # only referenced in commented-out code below
library(adespatial) # constr.hclust(), only referenced in commented-out code below
library(clValid) # clValid()
library(mgcv) # gam(), only referenced in commented-out code below
library(gratia) # fitted_samples(), only referenced in commented-out code below
library(lemon) # facet_rep_wrap()
# Default ggplot2 theme for all plots created afterwards
theme_set(theme_bw())
# Base-graphics defaults: L-shaped plot box, horizontal axis labels, thicker lines
par(bty = "l", las = 1, lwd = 2)

# Load country-level social data from Mistry, 2021 ------------------------
# Every file found in the folder is read as one contact matrix. The country
# name is the part of the file name that precedes "_country_level".
path_dat <- "_data/_contact_matrices/Mistry_2021/"
l_files <- list.files(path = path_dat)

# Country names, in the same order as the files
country_nm <- map_chr(
  .x = l_files,
  .f = \(f_nm) str_extract(string = f_nm, pattern = "^(.*?)(?=_country_level)")
)

# Containers filled by the loop below:
# - dat_SCM: one contact matrix per country (named list)
# - dat_pop: population by single year of age 0-79 (rows) and country (columns)
dat_SCM <- setNames(vector(mode = "list", length = length(l_files)), country_nm)
dat_pop <- matrix(
  data = 0,
  nrow = 80,
  ncol = length(country_nm),
  dimnames = list(as.character(0:79), country_nm)
)

age_drop <- 81:85 # Rows/columns of the age groups 80-84 yr, removed below

for(i in seq_along(l_files)) {
  file_cur <- l_files[i] # Current file
  
  # Contact matrix: header-less CSV, converted to a matrix without column names
  M_cur <- as.matrix(
    read_csv(file = paste0(path_dat, file_cur), col_names = FALSE, show_col_types = FALSE)
  )
  colnames(M_cur) <- NULL
  
  # Remove age groups 80-84 yr
  M_cur <- M_cur[-age_drop, -age_drop]
  
  # Empirical age distribution of the same country, ages 0-79 in increasing order
  demog_dat <- read_csv(
    file = sprintf("_data/_demog/_2010/%s_country_level_age_distribution_85.csv", country_nm[i]),
    col_names = c("age", "pop"),
    col_types = "d"
  )
  demog_dat <- filter(arrange(demog_dat, age), age <= 79)
  stopifnot(nrow(demog_dat) == 80) # Exactly one row per age is expected
  
  # Project to model population structure (Project_M() is defined outside this file)
  N_vec <- demog_dat$pop
  dat_SCM[[i]] <- Project_M(M = M_cur, N_tar = N_vec)
  dat_pop[, country_nm[i]] <- N_vec
}

# Visual check of one projected matrix
PlotMatrix(M_in = dat_SCM[["France"]], plot_title = "")

# Next-generation matrices ------------------------------------------------
# For every country, builds the matrix returned by compute_R0() (defined
# outside this file; stored as the next-generation matrix) and then takes its
# largest absolute eigenvalue as R0.
dat_NGM <- vector(mode = "list", length = length(l_files))
names(dat_NGM) <- country_nm

# Transmission parameters
q1_val <- 0.09373881
q2_val <- 0.5117505 * q1_val
q3_val <- 0.1725993 * q2_val
theta_val <- 0.9874919
gamma_val <- 365 / 15 # NOTE(review): presumably a per-year rate for a 15-day period — confirm
# One q value per single year of age: q1 for ages 0-9, q2 for ages 10-19, q3 for ages 20-79
q_vec <- c(rep(q1_val, 10), rep(q2_val, 10), rep(q3_val, 60))

# Birth rates of all countries: the file does not change between countries, so
# it is read once here instead of once per loop iteration
b_rate_dat <- read_csv2(file = "_data/_demog/birth_rates_2010.csv", col_names = TRUE, col_types = "cd")

for(i in seq_along(dat_NGM)) {
  nm_cur <- names(dat_NGM)[i]
  
  # Birth rate of the current country, divided by 1e3
  # NOTE(review): presumably the file holds births per 1,000 population per year — confirm
  b_rate <- b_rate_dat$birth_rate[b_rate_dat$country == nm_cur] / 1e3
  stopifnot(
    "Expected exactly one non-missing numeric birth rate per country" = 
      length(b_rate) == 1 && is.numeric(b_rate) && !is.na(b_rate)
  )
  
  # Aging rates: delta_vec[a] * N_vec[a] equals b_rate * Ntot_val for every age a,
  # i.e., the flow through each age class equals the total number of births
  delta_vec <- numeric(80)
  N_vec <- dat_pop[, nm_cur]
  Ntot_val <- sum(N_vec)
  delta_vec[1] <- b_rate * Ntot_val / N_vec[1]
  delta_vec[-1] <- delta_vec[1] * N_vec[1] / N_vec[-1]
  
  print(sprintf("Country: %s, Ntot = %.1fM, b_rate=%.3f per yr", nm_cur, Ntot_val / 1e6, b_rate))
  
  # NGM (contact matrix multiplied by 365 before being passed on)
  dat_NGM[[nm_cur]] <- compute_R0(theta = theta_val, 
                                  gamma = gamma_val, 
                                  N = N_vec, 
                                  q = q_vec, 
                                  delta = delta_vec, 
                                  Cmat = 365 * dat_SCM[[nm_cur]], 
                                  type = "SIR")
}

# R0 per country: largest absolute eigenvalue of the stored matrix
R0_vec <- map_dbl(.x = dat_NGM, .f = \(x) x %>% eigen() %>% pluck("values") %>% abs() %>% max())

# Plot demographic structure ----------------------------------------------
# Long format: one row per (country, age) with the population count and its
# share of the country's total population
dat_pop_long <- reshape2::melt(dat_pop, value.name = "pop") %>% 
  select(country = Var2, age = Var1, pop) %>% 
  group_by(country) %>% 
  mutate(pop_frac = pop / sum(pop)) %>% 
  ungroup()

# One grey line per country
pl <- ggplot(data = dat_pop_long, mapping = aes(x = age, y = pop_frac, group = country))
pl <- pl + 
  geom_line(color = "grey", alpha = 1) + 
  labs(x = "Age (years)", y = "Relative population")
print(pl)

# Plot degree distribution ------------------------------------------------
# Countries to display: first column of the text file
c_names <- read.table(file = "_data/list_countries.txt")[[1]]
dat_SCM_sub <- dat_SCM[c_names]

n_ages <- ncol(dat_SCM[[1]]) # No of age groups
age_seq <- 0:(n_ages - 1) # Ages, starting at 0

# Total contact rate by age (row sums of the contact matrix), stacked over countries
deg_dist <- lapply(dat_SCM_sub, \(M) data.frame(age = age_seq, r_cont = rowSums(M))) %>% 
  bind_rows(.id = "country")

# One colored line per country
deg_dist_pl <- filter(deg_dist, age <= 85)
pl <- ggplot(data = deg_dist_pl, mapping = aes(x = age, y = r_cont, color = country))
pl <- pl + 
  geom_line() + 
  scale_color_brewer(palette = "Set3") + 
  theme_classic() + 
  labs(x = "Age (years)", y = "Total contact rate (per day)", color = "Country")
print(pl)

#' Assortativity coefficient of a contact matrix
#'
#' Notations follow https://link.springer.com/article/10.1007/s40314-013-0017-7
#'
#' The matrix is first rescaled so that its entries sum to 1 (e). With a_i the
#' row sums and b_j the column sums of e, the function returns
#' (trace(e) - sum(a * b)) / (1 - sum(a * b)).
#'
#' @param M_in Numeric square matrix (rows and columns index the same groups).
#' @return A single numeric value.
r_Newman <- function(M_in) {
  # Contact matrix rescaled to sum to 1
  e_mat <- M_in / sum(M_in)
  
  # Value of the trace expected from the margins alone
  sum_ab <- sum(rowSums(e_mat) * colSums(e_mat))
  
  (sum(diag(e_mat)) - sum_ab) / (1 - sum_ab)
}

# Summary statistics
# One row per country, computed from the total contact rate by age (r_cont):
# - age_mode: age with the highest total contact rate
# - age_mean, age_var, age_sd: mean, variance and SD of age weighted by r_cont
# - r_cont_tot: half of the summed contact rates
# - mean_r_cont, sd_r_cont, max_r_cont: unweighted summaries of r_cont
# - assort_r: value returned by r_Newman() for the country's contact matrix
deg_dist_sum <- deg_dist %>% 
  group_by(country) %>% 
  summarise(age_mode = age[which.max(r_cont)], 
            age_mean = sum(age * r_cont) / sum(r_cont), 
            age_var = sum((age ^ 2) * r_cont) / sum(r_cont) - (age_mean ^ 2), 
            age_sd = sqrt(age_var),
            r_cont_tot = sum(r_cont) / 2, 
            mean_r_cont = mean(r_cont), 
            sd_r_cont = sd(r_cont), 
            max_r_cont = max(r_cont)
  ) %>% 
  ungroup() %>% 
  # summarise() orders its rows by the grouping key, which need not be the
  # order of dat_SCM_sub; each matrix is therefore looked up by the country
  # name of its row so that values cannot be attached to the wrong country
  mutate(assort_r = map_dbl(.x = country, .f = \(nm) r_Newman(dat_SCM_sub[[nm]])))


# Plot SCMs in 16 countries -----------------------------------------------
# Countries to display: those of the degree-distribution plot plus those listed
# in the second text file (duplicates removed)
c_names_all <- read.table(file = "_data/list_countries2.txt") %>% pluck(1)
c_names_all <- unique(c(c_names, c_names_all))

# Fail early, with the offending names, if a requested country was not loaded
missing_nm <- setdiff(c_names_all, names(dat_SCM))
if (length(missing_nm) > 0) {
  stop("No contact matrix loaded for: ", paste(missing_nm, collapse = ", "), call. = FALSE)
}

# Long format: one row per matrix cell (Var1 = row index, Var2 = column index)
tmp <- dat_SCM[c_names_all] 
names(tmp) <- c_names_all
tmp <- lapply(tmp, reshape2::melt) # melt() is exported, so `::` is sufficient
# NOTE(review): the original country names below mix "_" and "-" as separators
# ("United_States" vs "United-Kingdom") — confirm they match the file names
tmp <- tmp %>%
  bind_rows(.id = "country") %>% 
  mutate(country = factor(country)) %>% 
  mutate(country = fct_recode(country, 
                               "USA" = "United_States", 
                               "UK" = "United-Kingdom", 
                               "Czechia" = "Czech")) %>% 
  mutate(country = fct_relevel(country, "USA", after = Inf)) # USA in the last panel

# Pages of at most 8 countries each (16 countries give 1:8 and 9:16)
n_per_fig <- 8
lev_all <- levels(tmp$country)
i_vec <- unname(split(seq_along(lev_all), ceiling(seq_along(lev_all) / n_per_fig)))

for(i in seq_along(i_vec)) {
  # Matrix indices start at 1, hence the "- 1" to display ages starting at 0
  pl <- ggplot(data = tmp %>% filter(country %in% lev_all[i_vec[[i]]]), 
               mapping = aes(x = Var1 - 1, y = Var2 - 1, fill = value)) + 
    geom_tile() + 
    scale_fill_viridis(option = "rocket", trans = "sqrt") + 
    facet_rep_wrap(~ country, scales = "fixed", ncol = 2) + 
    theme_classic() + 
    theme(legend.position = "top", panel.grid = element_blank(), strip.background = element_blank()) + 
    labs(x = "Age (years)", y = "Age (years)", fill = "Daily contact rate")
  print(pl)
  
  # One PDF per page
  ggsave(plot = pl, filename = sprintf("_figures/main/fig_SCMs_%d.pdf", i), width = 8, height = 8)
}

# Estimate age with peak contacts using a GAM -----------------------------
# c_name <- "United_States"
# 
# dd <- deg_dist %>% 
#   filter(country == c_name, age <= 30) %>% 
#   arrange(age)
# 
# # Fit GAM
# mod <- gam(formula = r_cont ~ s(age), data = dd)
# 
# # Extract samples of fitted values
# fs <- fitted_samples(model = mod, n = 500) %>% 
#   mutate(age = .row - 1)
# 
# # Plot
# pl <- ggplot(data = dd, mapping = aes(x = age, y = r_cont)) + 
#   geom_line(data = fs, 
#             mapping = aes(x = age, y = .fitted, group = .draw), color = "grey", alpha = 0.5) + 
#   geom_line() 
# print(pl)
# 
# # Estimate peak age
# age_max <- fs %>% 
#   group_by(.draw) %>% 
#   summarise(age_max = age[which.max(.fitted)]) %>% 
#   ungroup()
# print(quantile(age_max$age_max, probs = c(0.025, 0.5, 0.975)))

# Put data in matrix format for clustering (row: country) --------------------------------
dat_cur <- dat_NGM # Matrices to be clustered

# Each matrix is flattened to a numeric vector, the vectors are bound together
# and the result is transposed
dat_cur_vec <- lapply(dat_cur, as.numeric)
dat_cur_mat <- t(bind_rows(dat_cur_vec))

# Determine optimal no of clusters ---------------------------------------------

# Pairwise dissimilarities between the rows of dat_cur_mat, shown as a heatmap
dist_nm <- "manhattan" # Name of distance measure
dist_mat <- get_dist(x = dat_cur_mat, method = dist_nm, stand = FALSE)
dist_cols <- list(high = "#08306b", mid = "#6baed6", low = "#f7fbff")
fviz_dist(dist.obj = dist_mat, gradient = dist_cols)

# Tried to run this, but it took too long 
# nb <- NbClust(data = dat_cur_mat,
#               diss = dist_mat,
#               distance = NULL,
#               index = 'silhouette', 
#               min.nc = 5,
#               max.nc = 15,
#               method = "ward.D2")

# Internal validation for 5 to 16 clusters and several clustering algorithms
# For the silhouette method, both pam and hclust suggest an optimal no of 10
cl_methods <- c("hierarchical", "agnes", "diana", "fanny", "clara", "pam")
nb <- clValid(
  obj = dat_cur_mat,
  metric = dist_nm,
  nClust = 5:16,
  clMethods = cl_methods,
  validation = "internal",
  method = "ward"
)
print(summary(nb))

# Other function to determine optimal no of clusters
# (alternatives tried: FUNcluster = pam, method = "wss")
fviz_nbclust(
  x = dat_cur_mat,
  #FUNcluster = pam, 
  FUNcluster = hcut,
  diss = dist_mat,
  #method = "wss", 
  method = "silhouette",
  k.max = 20,
  nboot = 50
)

# Run clustering for the chosen no of clusters -----------------------------
# NOTE(review): the original header of this section said k = 10, but the value
# used is 12 — confirm the intended number of clusters
k_which <- 12

# Hierarchical clustering, with two implementations of Ward's method
cl_hc <- hclust(d = dist_mat, method = "ward.D2")
cl_agnes <- cluster::agnes(x = dist_mat, diss = TRUE, metric = dist_nm, stand = FALSE, method = "ward")
cl_which <- cl_agnes # Tree displayed below

# Dendrograms cut into k_which clusters: phylogenic layout, then rectangular
# layout with a box around each cluster
fviz_dend(x = cl_which, k = k_which, type = "phylogenic", k_colors = "npg")
fviz_dend(
  x = cl_which,
  k = k_which,
  type = "rectangle",
  k_colors = "npg",
  rect = TRUE,
  rect_lty = 1,
  main = "",
  lower_rect = -110
)

# PAM algorithm
#cl_pam <- pam(x = dist_mat, k = 5, diss = T)

# Cluster by age with contiguity constraints ------------------------------
# dat_age <- dat_SCM[["United-States"]]
# dat_age <- dat_age[26:85, 26:85]
# rownames(dat_age) <- paste0("Age ", 25:84)
# #rownames(dat_age) <- paste0("Age ", 0:84)
# d_age <- get_dist(x = dat_age, method = "manhattan", stand = F)
# fviz_dist(dist.obj = d_age)
# 
# nb <- clValid(obj = dat_age, 
#               metric = "manhattan",
#               nClust = 2:10, 
#               clMethods = c("hierarchical"), 
#               validation = "internal", 
#               method = "ward")
# print(summary(nb))
# 
# cl_age <- constr.hclust(d = d_age, method = "ward.D2", chron = T)
# cutree(tree = cl_age, k = 2)
# plot(cl_age, k = 2)


#######################################################################################################
# END
#######################################################################################################