"""Adapted genomap functions from https://github.com/xinglab-ai/genomap/blob/main/genomap/genomap.py
and created utils for plotting"""
from anndata import AnnData
import pandas as pd
# Please install pandas and matplotlib before you run this example
import matplotlib.pyplot as plt
import matplotlib
# Set the Matplotlib backend to 'Agg'
matplotlib.use('Agg')
import numpy as np
import scipy.sparse as sp
import sklearn.metrics as mpd
import gc
import os
import glob
from scipy.spatial.distance import pdist, squareform
from scMEDAL.utils.utils import read_adata, min_max_scaling,save_adata,calculate_zscores
# I run it with Aixa_genomap


def createInteractionMatrix(data, metric='correlation'):
    """
    Return the pairwise interaction (distance) matrix among the genes.

    Function from genomap github:
    https://github.com/xinglab-ai/genomap/blob/main/genomap/genomap.py
    It is duplicated here because the result can contain NaN values;
    construct_genomap post-processes the matrix with np.nan_to_num.

    Parameters
    ----------
    data : ndarray, shape (cellNum, geneNum)
         gene expression data in cell X gene format. Each row corresponds
         to one cell, whereas each column represents one gene
    metric : 'string'
         Metric for computing the genetic interaction

    Returns
    -------
    interactMat : ndarray, shape (geneNum, geneNum)
           pairwise interaction matrix among genes
    """
    # Transpose so that genes (columns of `data`) become the compared samples.
    return mpd.pairwise_distances(data.T, metric=metric)


def create_gene_coordinates_mapping(projMat, gene_names, num_genes=2916, rowNum=54, colNum=54):
    """
    Map each gene name to the (x, y) grid cell that `projMat` sends it to.

    A diagnostic row vector [0, 1, ..., num_genes - 1] is multiplied by
    `projMat`, the result is rounded to 2 decimals, reshaped (C order) into a
    rowNum x colNum grid, and every grid value is converted with int() (which
    truncates toward zero) and interpreted as a gene index.

    Note that if 'projMat' contains fractional assignments (e.g., from
    Gromov-Wasserstein), converting to int can cause collisions: several grid
    cells may resolve to the same gene index (the last cell visited wins) and
    some gene indices may never appear. Any missing genes are reported in the
    console.

    Parameters
    ----------
    projMat : ndarray
         projection matrix; the product with the diagnostic vector must hold
         exactly rowNum * colNum values.
    gene_names : sequence
         gene names, indexed by gene index.
    num_genes : int
         length of the diagnostic index vector (rows of projMat).
    rowNum : int
         number of rows of the grid.
    colNum : int
         number of columns of the grid.

    Returns
    -------
    gene_to_coordinates : dict
         maps gene name -> (x, y), where x is the grid row index and y the
         grid column index of the reshaped diagnostic matrix.
    """
    # Create and transform the diagnostic matrix
    diagnostic_matrix = np.arange(num_genes).reshape(1, -1)
    transformed_indices = np.matmul(diagnostic_matrix, projMat).flatten()
    # Rounding to 2 decimals absorbs floating point noise (e.g. 4.999999 -> 5.0)
    # before the int() truncation below.
    px = np.round(transformed_indices, 2)

    # Reshape into a rowNum, colNum matrix. Default 'C' order (not transposed);
    # order='F' would return transposed indexes.
    genomaps_diagnostic = np.reshape(px, (rowNum, colNum), order='C')

    # Map genes to coordinates, remembering which gene indices were seen so the
    # grid only has to be scanned once.
    n_names = len(gene_names)
    gene_to_coordinates = {}
    found_indices = set()
    for x in range(rowNum):
        for y in range(colNum):
            gene_index = int(genomaps_diagnostic[x, y])
            if 0 <= gene_index < n_names:
                gene_to_coordinates[gene_names[gene_index]] = (x, y)
                found_indices.add(gene_index)
            else:
                print(f"Index {gene_index} not assigned to a gene name")

    missing_indices = set(range(n_names)) - found_indices
    print("Missing indices:", missing_indices)

    return gene_to_coordinates

# Example usage:
# projMat = np.random.rand(2916, 2916)  # Example initialization; replace with actual matrix
# gene_names = ['Gene1', 'Gene2', ..., 'Gene2916']  # Define a list of gene names
# mapping = create_gene_coordinates_mapping(projMat, gene_names)


def construct_genomap(data, rowNum, colNum, epsilon=0, num_iter=1000):
    """
    Adapted function from genomap github:
    https://github.com/xinglab-ai/genomap/blob/main/genomap/genomap.py

    Constructs 2D "genomaps" by coupling a gene-gene interaction matrix with a
    grid distance matrix using Gromov-Wasserstein. Note that GW transport
    yields fractional assignments, so forcibly rounding positions to integer
    grid coordinates can cause collisions (duplicate integer indices) and
    result in missing or overwritten gene indices.

    NaN values in the interaction matrix are replaced with np.nan_to_num.

    NOTE(review): the reviewed copy of this file is truncated in the middle of
    this function (see the FIXME in the body). Only the setup steps survive,
    so the function currently raises NotImplementedError instead of returning.
    The Returns section below is taken from the original docstring.

    Parameters
    ----------
    data : ndarray, shape (cellNum, geneNum)
         gene expression data in cell X gene format. Each row corresponds
         to one cell, whereas each column represents one gene
    rowNum : int,
         number of rows in a genomap
    colNum : int,
         number of columns in a genomap
    epsilon : presumably forwarded to the Gromov-Wasserstein solver - TODO
         confirm once the missing code is restored
    num_iter : presumably the solver's iteration limit - TODO confirm once
         the missing code is restored

    Returns
    -------
    genomaps : ndarray, shape (rowNum, colNum, zAxisDimension, cell number)
           genomaps are formed for each cell. zAxisDimension is more than
           1 when 3D genomaps are created.

    Raises
    ------
    NotImplementedError
         always, until the truncated remainder of the function is restored.
    """
    from genomap.genomapOPT import create_space_distributions, gromov_wasserstein_adjusted_norm
    from genomap.genomap import createMeshDistance

    sizeData = data.shape
    numCell = sizeData[0]
    numGene = sizeData[1]
    # distance matrix of 2D genomap grid
    distMat = createMeshDistance(rowNum, colNum)
    # gene-gene interaction matrix
    interactMat = createInteractionMatrix(data, metric='correlation')
    # I added the following line to avoid nan in interactions matrix
    interactMat = np.nan_to_num(interactMat)

    totalGridPoint = rowNum * colNum

    # FIXME(review): the source is truncated from here on (it stops in the
    # middle of a comparison starting with `if (numGene`). The rest of this
    # function must be restored from version control; it is deliberately not
    # guessed. Failing loudly is safer than silently returning None.
    raise NotImplementedError(
        "construct_genomap is truncated in this copy of the source file; "
        "restore the remainder of the function from version control"
    )


# FIXME(review): SOURCE CORRUPTION. Everything between the remainder of
# construct_genomap and the tail of a later helper (one that returns
# `cell_ids_2plot`) is missing from the reviewed copy of this file - presumably
# the text between a "<" and a ">" character was stripped. The surviving residue
# is kept verbatim below so nothing is lost; restore the original code from
# version control.
#
#   if (numGene 0: selected_cell = random.choice(cells_in_batch) cell_ids_2plot.append(selected_cell) return cell_ids_2plot


def _clean_gene_name(name):
    """Return a gene name without the textual b'...' wrapper of a bytes repr."""
    if isinstance(name, bytes):
        # Real bytes objects are decoded instead of string-stripped.
        return name.decode('utf-8', errors='replace')
    if isinstance(name, str) and len(name) >= 2:
        # "b'GENE'" / 'b"GENE"' -> "GENE"
        if len(name) >= 3 and name[0] == 'b' and name[1] in ('"', "'") and name[-1] == name[1]:
            return name[2:-1]
        # "'GENE'" -> "GENE"
        if name[0] == "'" and name[-1] == "'":
            return name[1:-1]
    # Already clean (or not a string): leave untouched. In particular, names
    # that merely start or end with the letter "b" are NOT trimmed.
    return name


def compute_cell_stats_acrossbatchrecon(genomap, cell_indexes_batch_cf, genomap_coordinates, statistic='std', n_top_genes=10, order='C', path_2_genomap='', file_name='cell_id'):
    """
    Calculate standard deviation or variance for genomaps of a single cell across batch reconstructions.
    Update genomap_coordinates DataFrame (in place) and save it as CSV.

    Args:
        genomap: 4D numpy array with genomap data. Axis = 0 indicates individual cells.
        cell_indexes_batch_cf: List or array of cell indexes for batch reconstructions of single cell.
            The same cell was reconstructed for multiple batches.
        genomap_coordinates: DataFrame containing the columns 'gene_names', 'pixel_i' and 'pixel_j'.
            The columns named by `statistic`, 'Rank' and 'Top_N' are added and 'gene_names' is cleaned.
        statistic: 'std' for standard deviation, 'var' for variance (both computed with ddof=1).
        n_top_genes: Number of top genes to identify.
        order: 'C' for default coordinate order, 'F' for transposed coordinates
            (i.e., pixel_i and pixel_j are swapped).
        path_2_genomap (str): path to genomap directory. To save stats df.
        file_name (str): Name to save the file. Default:"cell_id".

    Returns:
        Updated genomap_coordinates DataFrame with standard deviation/variance, 'Rank' and 'Top_N'.

    Raises:
        ValueError: if `statistic` is not 'std'/'var' or `order` is not 'C'/'F'.

    Explanation:
    - Statistics (standard deviation or variance) are calculated based on the genomap data
      indexed by cell_indexes_batch_cf (channel 0 of the last axis only).
    - If order is 'C', pixel_i and pixel_j are used as they are to index the statistic array.
    - If order is 'F', pixel_i and pixel_j are swapped when indexing the statistic array.
    - 'Rank' ranks the absolute statistic in descending order (rank 1 = most variable);
      'Top_N' flags the n_top_genes rows with the smallest rank.
    """
    if statistic == 'std':
        reducer = np.std
    elif statistic == 'var':
        reducer = np.var
    else:
        raise ValueError("Statistic must be 'std' or 'var'")
    if order not in ('C', 'F'):
        # Previously an unknown order silently produced an all-NaN column.
        raise ValueError("order must be 'C' or 'F'")

    # Calculate the standard deviation or variance along axis 0 (batches)
    stat_across_batches = reducer(genomap[cell_indexes_batch_cf, :, :, 0], axis=0, ddof=1)
    print(f"Shape of {statistic}_across_batches:", stat_across_batches.shape)

    # Look up the statistic of every gene's pixel in one fancy-indexing read.
    pixel_i = genomap_coordinates['pixel_i'].astype(int).to_numpy()
    pixel_j = genomap_coordinates['pixel_j'].astype(int).to_numpy()
    if order == 'C':
        genomap_coordinates[statistic] = stat_across_batches[pixel_i, pixel_j]
    else:
        genomap_coordinates[statistic] = stat_across_batches[pixel_j, pixel_i]

    # Rank the pixels based on the absolute value of standard deviation/variance
    genomap_coordinates['Rank'] = genomap_coordinates[statistic].abs().rank(ascending=False)

    # Convert gene_names from bytes (or their "b'...'" text form) to plain strings
    genomap_coordinates['gene_names'] = genomap_coordinates['gene_names'].apply(_clean_gene_name)

    # Add a "Top N" column with True/False
    genomap_coordinates['Top_N'] = False
    top_n_indices = genomap_coordinates.nsmallest(n_top_genes, 'Rank').index
    genomap_coordinates.loc[top_n_indices, 'Top_N'] = True

    # Save
    genomap_coordinates.to_csv(os.path.join(path_2_genomap, f"genomap_{n_top_genes}topvariablegenesacrossbatches_{file_name}_{statistic}.csv"))

    return genomap_coordinates


def adjust_text_positions(x, y, threshold=0.5, offset=0.2):
    """
    Adjust text positions to avoid overlap.

    Every point is shifted by `offset` (per axis, away from the neighbour) once
    for each other point that lies closer than `threshold`.

    Args:
        x: List or array of x coordinates.
        y: List or array of y coordinates.
        threshold: Minimum distance to maintain between points.
        offset: Distance to shift overlapping points.

    Returns:
        List of adjusted (x, y) coordinates; an empty list for empty input.
    """
    points = np.array(list(zip(x, y)))
    if points.size == 0:
        # pdist needs a 2-D array and would raise on empty input.
        return []

    # Calculate pairwise distances
    dists = squareform(pdist(points))

    # Adjust positions to avoid overlap
    adjusted_positions = []
    for i, (x_i, y_i) in enumerate(points):
        shift_x, shift_y = 0, 0
        for j, (x_j, y_j) in enumerate(points):
            if i != j and dists[i, j] < threshold:
                shift_x += offset if x_i <= x_j else -offset
                shift_y += offset if y_i <= y_j else -offset
        adjusted_positions.append((x_i + shift_x, y_i + shift_y))

    return adjusted_positions


def plot_cell_recon_genomap(genomap, cell_indexes, genomap_coordinates, obs, original_batch=None, n_top_genes=10, min_val=-5, max_val=10, n_cols=3, order='C', path_2_genomap='', file_name="cell_id", remove_ticks=False):
    """
    Plot the genomap with the top variable genes highlighted.

    One panel is drawn per entry of `cell_indexes`; the figure is saved to
    path_2_genomap and then closed.

    Args:
        genomap: 4D numpy array with genomap data.
        cell_indexes: List or array of cell indexes. Used both to index axis 0 of
            `genomap` and as row labels of `obs`.
        genomap_coordinates: DataFrame containing gene names and pixel coordinates and a
            boolean 'Top_N' column, or None to skip the gene labels.
        obs: DataFrame containing cell metadata including the column 'recon_prefix' that indicates
            the name of the reconstruction for a cell id.
        original_batch: Identifier for the original batch, if any. Default is None.
        n_top_genes: Number of top genes to highlight (only used in the output file name).
        min_val: Minimum value for color scale.
        max_val: Maximum value for color scale.
        n_cols (int): Number of columns for the subplot. Default is 3.
        order: 'C' for default coordinate order, 'F' for transposed coordinates
            (i.e., pixel_i and pixel_j are swapped).
        path_2_genomap (str): path to genomap directory. To save plot.
        file_name (str): Name to save the file. Default:"cell_id"
        remove_ticks (bool,optional): Flag to remove ticks from the plot. Default=False.

    Raises:
        ValueError: if `cell_indexes` is empty, or if `genomap_coordinates` is given and
            `order` is not 'C'/'F'.
    """
    n_images = len(cell_indexes)
    if n_images == 0:
        raise ValueError("cell_indexes must contain at least one cell index")

    geno_slices_cell_id = genomap[cell_indexes, :, :, 0]
    n_rows = (n_images + n_cols - 1) // n_cols

    # The gene markers/labels are identical on every panel, so compute them once.
    gene_labels = []
    if genomap_coordinates is not None:
        if order not in ('C', 'F'):
            raise ValueError("order must be 'C' or 'F'")
        top_n_coordinates = genomap_coordinates[genomap_coordinates['Top_N']]
        if order == 'C':
            x, y = top_n_coordinates['pixel_i'], top_n_coordinates['pixel_j']
        else:
            x, y = top_n_coordinates['pixel_j'], top_n_coordinates['pixel_i']
        # Adjust text positions
        adjusted_positions = adjust_text_positions(x, y, threshold=3, offset=6)
        gene_labels = list(zip(adjusted_positions, zip(x, y), top_n_coordinates['gene_names']))

    # squeeze=False always yields a 2D array of axes, whatever n_rows/n_cols are.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols + 1, 5 * n_rows), squeeze=False)
    try:
        for i, (cell_index, cell_geno) in enumerate(zip(cell_indexes, geno_slices_cell_id)):
            ax = axes[i // n_cols, i % n_cols]
            im = ax.imshow(cell_geno, cmap='viridis', vmin=min_val, vmax=max_val)

            # Remove ticks
            if remove_ticks:
                ax.set_xticks([])
                ax.set_yticks([])

            # Add gene labels if genomap_coordinates is provided
            for (adj_x, adj_y), (pixel_i, pixel_j), gene in gene_labels:
                ax.plot(pixel_i, pixel_j, 'o', markerfacecolor='none', markeredgecolor='red', markersize=6, markeredgewidth=2)
                ax.text(adj_x, adj_y, gene, color='black', ha='left', va='center', fontweight='bold', fontsize=12, clip_on=False)

            recon_prefix = obs.loc[cell_index, "recon_prefix"]
            if 'input' in recon_prefix:
                ax.set_title(f'{recon_prefix}\noriginal batch: {original_batch}')
            elif original_batch is not None and str(original_batch) in recon_prefix:
                # Guarded: `None in "text"` raises TypeError with the default original_batch.
                ax.set_title(f'{recon_prefix}\n(original batch)', color='red')
            else:
                ax.set_title(f'{recon_prefix}')

        # Hide any unused subplots
        for unused_ax in axes.flatten()[n_images:]:
            fig.delaxes(unused_ax)

        cbar_ax = fig.add_axes([0.94, 0.15, 0.005, 0.7])  # Adjusted for vertical color bar
        # All panels share vmin/vmax, so the last image is representative.
        fig.colorbar(im, cax=cbar_ax, orientation='vertical')
        fig.suptitle(file_name, fontsize=16)
        fig.tight_layout(rect=[0, 0.03, 0.9, 0.97])

        fig.savefig(os.path.join(path_2_genomap, f"genomap_{n_top_genes}topvariablegenesacrossbatches_{file_name}"))
    finally:
        # pyplot keeps every figure alive until it is closed; release it so that
        # calling this function in a loop does not accumulate memory.
        plt.close(fig)