Validation of RCTD on decomposition of simulated doublets

library(RCTD)
library(Matrix)
library(ggplot2)
library(ggpubr)
library(gridExtra)
library(reshape2)
library(readr)
library(Seurat)
source('../Plotting/figure_utils.R')

Load in RCTD results, calculate cell class identification rate

# The RData file loaded below was produced by the gather_results.R script via:
#   save(puck_d, iv, results, file = 'Data/SpatialRNA/Puck_Viz/results/gathered_results.RData')
# so load() restores `puck_d`, `iv` and `results` into the workspace.
refdir <- '../../Data/Reference/DropVizHC'
load('../../Data/SpatialRNA/Puck_Viz/results/gathered_results.RData')

# Simulation metadata stored next to the puck.
metadir <- file.path(paste0('../../', 'Data/SpatialRNA/Puck_Viz'), "MetaData")
meta_data <- readRDS(file.path(metadir, "meta_data.RDS"))

# Frequently used pieces of the loaded objects.
results_df <- results$results_df
meta_df <- meta_data$meta_df
UMI_tot <- meta_data$UMI_tot
UMI_list <- meta_data$UMI_list

#' Build the lookup table that maps each cell type to its cell class.
#'
#' @param cell_type_names Vector of cell type names. It is used both as the
#'   row names of the result and as the initial content of the `class` column.
#' @param use_classes If TRUE, closely related cell types are collapsed into a
#'   shared class (Bergmann -> Astrocytes, Fibroblast -> Endothelial,
#'   MLI2 -> MLI1, Macrophages -> Microglia,
#'   Polydendrocytes -> Oligodendrocytes). If FALSE every type is its own class.
#' @return A data.frame with a single character column `class` and one row per
#'   entry of `cell_type_names`.
get_class_df <- function(cell_type_names, use_classes = FALSE) {
  # stringsAsFactors = FALSE keeps `class` a plain character column on every R
  # version, so a merged class that is not itself a listed type can be stored.
  class_df <- data.frame(class = as.character(cell_type_names),
                         row.names = cell_type_names,
                         stringsAsFactors = FALSE)
  if (use_classes) {
    merged_classes <- c(
      Bergmann = "Astrocytes",
      Fibroblast = "Endothelial",
      MLI2 = "MLI1",
      Macrophages = "Microglia",
      Polydendrocytes = "Oligodendrocytes"
    )
    # Only remap types that are actually present: assigning to a missing row
    # name would silently append a new row to the table.
    present <- intersect(names(merged_classes), rownames(class_df))
    if (length(present) > 0) {
      class_df[present, "class"] <- unname(merged_classes[present])
    }
  }
  return(class_df)
}
# Cell type names come from the loaded `iv` object; related types are merged
# into shared classes by get_class_df().
cell_type_names <- iv$cell_type_info[[2]]
class_df <- get_class_df(cell_type_names, use_classes = TRUE)

resultsdir <- file.path(iv$slideseqdir, "results")

# Confusion matrix over the beads at the two extreme UMI conditions.
# Presumably first_UMI == 0 means the bead consists only of its second type and
# first_UMI == UMI_tot only of its first type -- confirm against the simulation.
only_second_type <- meta_df$first_UMI == 0
only_first_type <- meta_df$first_UMI == UMI_tot
true_types <- unlist(list(
  meta_df[only_second_type, "second_type"],
  meta_df[only_first_type, "first_type"]
))
pred_types <- unlist(list(
  results_df[only_second_type, "first_type"],
  results_df[only_first_type, "first_type"]
))
conf_mat <- caret::confusionMatrix(
  pred_types,
  factor(true_types, levels = cell_type_names)
)
conf_mat_RCTD <- conf_mat$table

# Cell types shared by the simulation and the reference.
common_cell_types <- c(
  "Astrocytes", "Bergmann", "Endothelial", "Fibroblast", "Golgi", "Granule",
  "MLI1", "MLI2", "Oligodendrocytes", "Polydendrocytes", "Purkinje", "UBCs"
)
square_results <- plot_doublet_identification_certain(
  meta_data, common_cell_types, './', meta_df, results_df, class_df, UMI_list
)

Compute and plot doublet classification rate

# Doublet classification rate as a function of the minority cell type's share
# of the UMIs.
# Assumes UMI_list is ordered so that entry k and entry (length - k + 1) are
# mirror-image conditions, with one unpaired middle entry -- TODO confirm.
n_UMI_total <- length(UMI_list)
N_UMI_cond <- (n_UMI_total + 1) / 2

# One row per UMI condition, allocated up front so the loop never appends rows.
spot_class_df <- data.frame(
  singlets = numeric(n_UMI_total),
  doublets = numeric(n_UMI_total),
  row.names = as.character(UMI_list)
)
barcodes <- rownames(results_df)
for (nUMI in UMI_list) {
  spot_classes <- results_df[barcodes[meta_df$first_UMI == nUMI], ]$spot_class
  # Count directly instead of indexing a table by name: a class that does not
  # occur would otherwise produce NA and poison the whole row.
  spot_class_df[as.character(nUMI), 'singlets'] <-
    sum(spot_classes == "singlet", na.rm = TRUE)
  # Everything that is neither a singlet nor rejected counts as a doublet call.
  spot_class_df[as.character(nUMI), 'doublets'] <-
    sum(!is.na(spot_classes) & !(spot_classes %in% c("singlet", "reject")))
}

# Fold each condition onto its mirror image (1 with last, 2 with second to
# last, ...); the middle condition has no partner and is kept as is.
if (N_UMI_cond > 1) {
  lower_half <- seq_len(N_UMI_cond - 1)
  mirror_half <- n_UMI_total + 1 - lower_half
  spot_class_df[lower_half, ] <- spot_class_df[lower_half, ] + spot_class_df[mirror_half, ]
}
spot_class_df <- spot_class_df[seq_len(N_UMI_cond), ]

# Convert counts to per-condition proportions.
spot_class_df <- sweep(spot_class_df, 1, rowSums(spot_class_df), '/')
spot_class_df[, "nUMI"] <- UMI_list[seq_len(N_UMI_cond)]
df <- melt(spot_class_df, id.vars = 'nUMI', variable.name = 'series')

# Binomial standard error. Folded conditions pool two groups of beads; the
# group size is taken from the first UMI condition.
n_beads_per_cond <- as.numeric(table(meta_df$first_UMI)[1])
df$std <- sqrt(df$value * (1 - df$value) / (2 * n_beads_per_cond))
# The unpaired middle condition only has a single group, hence sqrt(2) wider.
mid_UMI <- UMI_list[N_UMI_cond]
df$std[df$nUMI == mid_UMI] <- df$std[df$nUMI == mid_UMI] * sqrt(2)
doublet_df <- df
my_pal <- pals::coolwarm(20)
# Express the UMI count as a fraction of the bead's total UMIs.
doublet_df$prop <- doublet_df$nUMI / UMI_tot
plot_df <- doublet_df[doublet_df$series == "doublets", ]
# The colour scale below has no effect on this plot because no colour aesthetic
# is mapped; it is retained unchanged.
p1 <- ggplot(plot_df, aes(prop, value)) +
  geom_line() +
  geom_errorbar(aes(ymin = value - 1.96 * std, ymax = value + 1.96 * std),
                width = .02, position = position_dodge(.001)) +
  theme_classic() +
  geom_hline(yintercept = 0, linetype = "dashed", color = "grey", size = 0.5) +
  geom_hline(yintercept = 1, linetype = "dashed", color = "grey", size = 0.5) +
  xlab("UMI Proportion of Minority Cell Type") +
  ylab("Doublet Classification Rate") +
  scale_color_manual(values = c(my_pal[1], my_pal[20]), labels = c("Doublet"), name = "Class") +
  scale_y_continuous(breaks = c(0, 0.5, 1), limits = c(0, 1)) +
  scale_x_continuous(breaks = c(0, 0.25, 0.5), limits = c(-.03, 0.53))


ggarrange(p1)

Plot cell class identification rate

# Reload the gathered RCTD results (restores `iv` and `results`) and the
# simulation metadata. The paths use the same 'Data/' casing as the loading
# code at the top of the file; the lowercase 'data/' spelling only resolves on
# case-insensitive file systems.
load('../../Data/SpatialRNA/Puck_Viz/results/gathered_results.RData')
# Cell type -> cell class lookup with related types merged.
class_df <- get_class_df(iv$cell_type_info[[2]], use_classes = TRUE)
results_df <- results$results_df
metadir <- "../../Data/SpatialRNA/Puck_Viz/MetaData"
meta_data <- readRDS(file.path(metadir, "meta_data.RDS"))
meta_df <- meta_data$meta_df

# Look classes up by name. as.character() matters when the type columns are
# factors: a factor used as a row index selects by integer code, not by label.
true_first_class <- class_df[as.character(meta_df$first_type), "class"]
true_second_class <- class_df[as.character(meta_df$second_type), "class"]
pred_first_class <- class_df[as.character(results_df$first_type), "class"]
pred_second_class <- class_df[as.character(results_df$second_type), "class"]

# A predicted type is correct when its class matches either true class.
first_pred <- pred_first_class == true_first_class |
  pred_first_class == true_second_class
first_conf <- results_df$spot_class != 'reject'

second_pred <- pred_second_class == true_first_class |
  pred_second_class == true_second_class
second_conf <- results_df$spot_class == 'doublet_certain'

# Only beads whose first-type UMI count lies in [250, 750] are evaluated.
in_mid_range <- 250 <= meta_df$first_UMI & meta_df$first_UMI <= 750

# Successes (sum) and totals (length) per (first type, second type) pair.
res_df <- data.frame(second_pred, meta_df$first_type, meta_df$second_type)[second_conf & in_mid_range, ]
sec_suc <- aggregate(. ~ meta_df.first_type + meta_df.second_type, res_df, sum)
sec_tot <- aggregate(. ~ meta_df.first_type + meta_df.second_type, res_df, length)
res_df <- data.frame(first_pred, meta_df$first_type, meta_df$second_type)[first_conf & in_mid_range, ]
fir_suc <- aggregate(. ~ meta_df.first_type + meta_df.second_type, res_df, sum)
fir_tot <- aggregate(. ~ meta_df.first_type + meta_df.second_type, res_df, length)

# The tables are combined row by row below, so they must list the same type
# pairs in the same order; fail loudly instead of mixing up pairs.
stopifnot(identical(
  paste(sec_suc$meta_df.first_type, sec_suc$meta_df.second_type),
  paste(fir_suc$meta_df.first_type, fir_suc$meta_df.second_type)
))
sec_suc$accuracy_mat <- (fir_suc$first_pred + sec_suc$second_pred) / (fir_tot$first_pred + sec_tot$second_pred)

colnames(sec_suc) <- c('Prediction', 'Reference', 'skip', 'value')
# Duplicate every pair with its two types swapped so the cast below yields a
# symmetric table.
n_pairs <- nrow(sec_suc)
original_rows <- seq_len(n_pairs)
mirrored_rows <- n_pairs + original_rows
sec_suc_tot <- rbind(sec_suc, sec_suc)
sec_suc_tot$Prediction <- as.character(sec_suc_tot$Prediction)
sec_suc_tot$Reference <- as.character(sec_suc_tot$Reference)
sec_suc_tot$Prediction[mirrored_rows] <- sec_suc_tot$Reference[original_rows]
sec_suc_tot$Reference[mirrored_rows] <- sec_suc_tot$Prediction[original_rows]
square_results <- dcast(sec_suc_tot, formula = Reference ~ Prediction, value.var = "value")
print(mean(sec_suc_tot$value[original_rows]))
print(sd(sec_suc_tot$value[original_rows]))

rownames(square_results) <- square_results$Reference
square_results$Reference <- NULL
# Merge the rows of cell types that belong to the same class.
# Assumes the rows are the 12 cell types in alphabetical order: 1 Astrocytes,
# 2 Bergmann, 3 Endothelial, 4 Fibroblast, 5 Golgi, 6 Granule, 7 MLI1, 8 MLI2,
# 9 Oligodendrocytes, 10 Polydendrocytes, 11 Purkinje, 12 UBCs -- TODO confirm.
square_results[1, ] <- colMeans(square_results[1:2, ], na.rm = TRUE)
square_results[3, ] <- colMeans(square_results[3:4, ], na.rm = TRUE)
square_results[7, ] <- colMeans(square_results[7:8, ], na.rm = TRUE)
# Rows 9:10 only; the previous 7:10 also averaged the MLI rows into this class.
square_results[9, ] <- colMeans(square_results[9:10, ], na.rm = TRUE)
square_results <- square_results[c(1, 3, 5, 6, 7, 9, 11, 12), ]
rownames(square_results)[1] <- "Astrocytes/Bergmann"
rownames(square_results)[2] <- "Endo./Fibroblast"
rownames(square_results)[5] <- "MLI"
rownames(square_results)[6] <- "Oligoden./Polyden."
data <- melt(as.matrix(square_results))
colnames(data) <- c('Prediction', 'Reference', 'value')
mean(data$value, na.rm = TRUE)
#p2 <- ggplot(data, aes(Prediction, Reference, fill= value)) +  geom_tile() +theme_classic() +scale_fill_gradientn(colors = pals::brewer.reds(20)[1:20], limits= c(0,1),name = "Identification Rate")+ theme(axis.text.x = element_text(angle = 45, hjust = 1)) + xlab('First Cell Type')+ ylab('Second Cell Type')
diverse_pal <- pals::kelly(13)[2:13]
# NOTE(review): the guide passed to guides() is unnamed and therefore appears
# to be ignored by ggplot2; name it (color = guide_legend(...)) if the 6-row
# legend layout is actually wanted.
p2 <- ggplot(data, aes(Prediction, value, group = Reference, color = Reference)) +
  geom_jitter(width = 0.3, size = 1) +
  theme_classic() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  xlab('Cell Class 1') +
  ylab('Cell Class Identification Rate') +
  scale_color_manual(values = diverse_pal, name = "Cell Type 2") +
  theme(axis.text = element_text(size = 8)) +
  theme(axis.title = element_text(size = 10)) +
  theme(legend.text = element_text(size = 8),
        legend.spacing.x = unit(-0.1, 'cm'),
        legend.spacing.y = unit(-0.1, 'cm')) +
  guides(guide_legend(nrow = 6, byrow = TRUE)) +
  scale_y_continuous(breaks = c(0, 0.5, 1), limits = c(-.0001, 1.0001)) +
  theme(legend.position = "top")
ggarrange(p2)

### Load in RCTD results for estimating doublet cell type proportion

iv <- init_RCTD(load_info_renorm = TRUE) #initial variables
cell_type_names <- iv$cell_type_info[[2]]
# Use the full element name: `iv$slideseq` only resolved through `$` partial
# matching of `slideseqdir`, which is the name used everywhere else.
metadir <- file.path(iv$slideseqdir, "MetaData")
meta_df <- meta_data$meta_df
barcodes <- rownames(meta_df)
N <- length(barcodes)
first_beads <- readRDS(file.path(metadir, "firstbeads.RDS"))
second_beads <- readRDS(file.path(metadir, "secbeads.RDS"))
n_genes <- length(iv$gene_list)

# Gather the per-fold decomposition results. The expectation and variance
# matrices are built once from the per-fold pieces after the loop, so no
# N x genes placeholder matrices are allocated beforehand.
weights_doublet <- Matrix(0, nrow = N, ncol = 2)
index <- 1
expect1_list <- vector("list", iv$n_puck_folds)
expect2_list <- vector("list", iv$n_puck_folds)
var_list <- vector("list", iv$n_puck_folds)
for (fold_index in seq_len(iv$n_puck_folds)) {
  print(fold_index)
  results <- readRDS(paste0(iv$slideseqdir, "/DecomposeResults/results", fold_index, ".RDS"))
  # Rows of the full bead set covered by this fold.
  fold_rows <- index + seq_along(results) - 1
  weights_doublet[fold_rows, ] <- do.call(rbind, lapply(results, function(x) x$weights))
  # Each bead contributes one vector per quantity; stack them as one row per bead.
  exp_list <- lapply(results, function(x) x$decompose_results$expect_1)
  expect1_list[[fold_index]] <- matrix(unlist(exp_list), ncol = n_genes, byrow = TRUE)
  exp_list <- lapply(results, function(x) x$decompose_results$expect_2)
  expect2_list[[fold_index]] <- matrix(unlist(exp_list), ncol = n_genes, byrow = TRUE)
  exp_list <- lapply(results, function(x) x$decompose_results$variance)
  var_list[[fold_index]] <- matrix(unlist(exp_list), ncol = n_genes, byrow = TRUE)
  index <- index + length(results)
}

# Bind the folds together and free the per-fold pieces as soon as possible.
expect1 <- do.call(rbind, expect1_list)
rm(expect1_list)
expect2 <- do.call(rbind, expect2_list)
rm(expect2_list)
var <- do.call(rbind, var_list)
rm(var_list)
rownames(expect1) <- barcodes
rownames(expect2) <- barcodes
rownames(var) <- barcodes
colnames(expect1) <- iv$gene_list
colnames(expect2) <- iv$gene_list
colnames(var) <- iv$gene_list
rownames(weights_doublet) <- barcodes
colnames(weights_doublet) <- c('first_type', 'second_type')

beads <- expect1 + expect2

#weight recovery plot
DropViz <- TRUE
if (DropViz) {
  common_cell_types <- c("Astrocytes", "Bergmann", "Endothelial", "Fibroblast", "Golgi", "Granule", "MLI1", "MLI2", "Oligodendrocytes", "Polydendrocytes", "Purkinje", "UBCs")
} else {
  common_cell_types <- iv$cell_type_info[[2]]
}

# Pairwise summary matrices, indexed by cell type name on both axes.
n_types <- meta_data$n_cell_types
type_labels <- common_cell_types[seq_len(n_types)]
RMSE_dat <- Matrix(0, nrow = n_types, ncol = n_types)
R2_dat <- Matrix(0, nrow = n_types, ncol = n_types)
DE_dat <- Matrix(0, nrow = n_types, ncol = n_types)
rownames(RMSE_dat) <- type_labels
colnames(RMSE_dat) <- type_labels
rownames(R2_dat) <- type_labels
colnames(R2_dat) <- type_labels
rownames(DE_dat) <- type_labels
colnames(DE_dat) <- type_labels
# List-matrix holding the full get_decompose_plots() result of every pair.
plot_results <- as.list(rep(0, n_types^2))
dim(plot_results) <- c(n_types, n_types)
dimnames(plot_results) <- list(type_labels, type_labels)
for (ind1 in seq_len(n_types - 1)) {
  for (ind2 in (ind1 + 1):n_types) {
    type1 <- common_cell_types[ind1]
    type2 <- common_cell_types[ind2]
    print(paste(type1, type2))
    pair_result <- get_decompose_plots(meta_df, type1, type2, weights_doublet, meta_data, iv, expect1, expect2, var, first_beads, second_beads, beads)
    plot_results[[type1, type2]] <- pair_result
    # R2 and RMSE are stored symmetrically; the DE score differs by direction.
    R2_dat[type1, type2] <- pair_result$R2
    RMSE_dat[type1, type2] <- pair_result$RMSE
    R2_dat[type2, type1] <- pair_result$R2
    RMSE_dat[type2, type1] <- pair_result$RMSE
    DE_dat[type1, type2] <- pair_result$de_scores[1]
    DE_dat[type2, type1] <- pair_result$de_scores[2]
  }
}

# One multi-page PDF per plot kind, one page per cell type pair.
decomp_dir <- file.path(iv$slideseqdir, "DecomposePlots")
if (!dir.exists(decomp_dir)) {
  dir.create(decomp_dir)
}
for (plot_title in c('bias_plot', 'err_plot', 'hist_plot', 'de_gene_plot', 'weight_plot', 'de_ind_plot')) {
  pdf(file.path(decomp_dir, paste0(plot_title, ".pdf")))
  for (ind1 in seq_len(n_types - 1)) {
    for (ind2 in (ind1 + 1):n_types) {
      type1 <- common_cell_types[ind1]
      type2 <- common_cell_types[ind2]
      invisible(print(plot_results[[type1, type2]][[plot_title]]))
    }
  }
  dev.off()
}

plot_heat_map(as.matrix(R2_dat), normalize = FALSE, file_loc = file.path(decomp_dir, "geneR2.png"), save.file = TRUE)
plot_heat_map(as.matrix(RMSE_dat), normalize = FALSE, file_loc = file.path(decomp_dir, "RMSE_weights.png"), save.file = TRUE)
RMSE_dat <- as.matrix(RMSE_dat)
DE_dat <- as.matrix(DE_dat)
# The Bergmann/Purkinje pair is the example saved for the figures.
plot_df_weight <- plot_results[["Bergmann", "Purkinje"]]$plot_df_weight
plot_df_de <- plot_results[["Bergmann", "Purkinje"]]$de_plot_df
save(RMSE_dat, DE_dat, plot_df_weight, plot_df_de, file = "../Plotting/Results/decompose.RData")

Plot predicted vs true cell type proportion

# Restore RMSE_dat, DE_dat, plot_df_weight and plot_df_de.
# NOTE(review): this path differs from the "../Plotting/Results/decompose.RData"
# path the results are saved to elsewhere in this file -- confirm which location
# is intended.
load(file = "../Results/decompose.RData")
my_pal <- pals::coolwarm(20)

# Predicted proportion (points, line and error bars) against the identity line.
# NOTE(review): the error bars use st_dev / 1.96 -- confirm how st_dev is scaled.
p2 <- ggplot(plot_df_weight, aes(x = proportion, y = type1_avg, colour = "type1_avg")) +
  geom_line() +
  geom_point() +
  geom_line(aes(y = proportion, colour = "proportion")) +
  geom_errorbar(
    aes(ymin = type1_avg - st_dev / 1.96, ymax = type1_avg + st_dev / 1.96),
    width = .05,
    position = position_dodge(0.05)
  ) +
  theme_classic() +
  xlab('True Bergmann Proportion') +
  ylab('Predicted Bergmann Proportion') +
  scale_color_manual(values = c(my_pal[20], my_pal[1]), labels = c("", ""), name = "") +
  scale_y_continuous(breaks = c(0, 0.5, 1), limits = c(-.01, 1.01)) +
  scale_x_continuous(breaks = c(0, 0.5, 1), limits = c(-.03, 1.03)) +
  theme(legend.position = "none")
ggarrange(p2)

Plot Root Mean Squared Error of cell type proportion

# A cell type paired with itself is not a doublet; show zero error there.
diag(RMSE_dat) <- 0
# Shorten the two longest labels. Matching by name instead of by position
# (rows/columns 9 and 10) keeps this correct if the set or order of cell types
# changes, and makes re-running the chunk harmless.
short_labels <- c(Oligodendrocytes = "Oligoden.", Polydendrocytes = "Polyden.")
for (long_label in names(short_labels)) {
  rownames(RMSE_dat)[rownames(RMSE_dat) == long_label] <- short_labels[[long_label]]
  colnames(RMSE_dat)[colnames(RMSE_dat) == long_label] <- short_labels[[long_label]]
}
data <- melt(as.matrix(RMSE_dat))
colnames(data) <- c('Prediction', 'Reference', 'value')
# Only one y scale is added. The chain used to add two, and ggplot2 replaces an
# earlier scale with a later one, so the limits/breaks below are the ones that
# were effectively applied.
# NOTE(review): the replaced scale used breaks c(0, 0.25, 0.5) and limits
# c(-.0001, 0.5); switch to those if the tighter axis was the intended one.
# NOTE(review): the guide passed to guides() is unnamed and therefore appears
# to be ignored by ggplot2; name it (color = guide_legend(...)) if the 6-row
# legend layout is actually wanted.
p1 <- ggplot(data, aes(Prediction, value, group = Reference, color = Reference)) +
  geom_jitter(width = 0.3, size = 1) +
  theme_classic() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  xlab('Cell Type 1') +
  ylab('Root Mean Squared Error') +
  scale_color_manual(values = diverse_pal, name = "Cell Type 2") +
  theme(axis.text = element_text(size = 8)) +
  theme(axis.title = element_text(size = 10)) +
  theme(legend.text = element_text(size = 8),
        legend.spacing.x = unit(-0.1, 'cm'),
        legend.spacing.y = unit(-0.1, 'cm')) +
  guides(guide_legend(nrow = 6, byrow = TRUE)) +
  scale_y_continuous(breaks = c(0, 0.5, 1), limits = c(-.0001, 1.0001)) +
  theme(legend.position = "top")


ggarrange(p1)