Last updated: 2019-03-22

workflowr checks:
  • R Markdown file: up-to-date

    Great! Since the R Markdown file has been committed to the Git repository, you know the exact version of the code that produced these results.

  • Environment: empty

    Great job! The global environment was empty. Objects defined in the global environment can affect the analysis in your R Markdown file in unknown ways. For reproducibility it’s best to always run the code in an empty environment.

  • Seed: set.seed(20181103)

    The command set.seed(20181103) was run prior to running the code in the R Markdown file. Setting a seed ensures that any results that rely on randomness, e.g. subsampling or permutations, are reproducible.

  • Session information: recorded

    Great job! Recording the operating system, R version, and package versions is critical for reproducibility.

  • Repository version: b792cfe

    Great! You are using Git for version control. Tracking code development and connecting the code version to the results is critical for reproducibility. The version displayed above was the version of the Git repository at the time these results were generated.

    Note that you need to be careful to ensure that all relevant files for the analysis have been committed to Git prior to generating the results (you can use wflow_publish or wflow_git_commit). workflowr only checks the R Markdown file, but you know if there are other scripts or data files that it depends on. Below is the status of the Git repository when the results were generated:
    
    Ignored files:
        Ignored:    .DS_Store
        Ignored:    .Rhistory
        Ignored:    .Rproj.user/
        Ignored:    output/.DS_Store
        Ignored:    output/sc_comparisons/.DS_Store
    
    
    Note that any generated files, e.g. HTML, png, CSS, etc., are not included in this status report because it is ok for generated content to have uncommitted changes.
Past versions:
    File Version Author Date Message
    Rmd b792cfe Jason Willwerscheid 2019-03-22 wflow_publish(“analysis/sc_comparisons.Rmd”)


Introduction

Here I compare various methods for fitting flash objects to single-cell data. I use the drop-seq dataset discussed in Montoro et al., which includes counts for approximately 18000 genes and 7000 cells. Here, I only retain basal cells, of which there are approximately 3000. Close to 90% of the basal counts are equal to zero.

I perform 50 trials. For each trial, I subsample 200 basal cells, then I subsample 500 genes from the set of genes that have at least 3 non-zero counts among the subsampled cells. I remove 1000 entries at random from the resulting data matrix, fit a series of flash objects, and calculate mean squared errors, using flash posterior means as fitted values and the original data as ground truth. I compare different variance structures, data transformations, scaling methods, priors, and numbers of factors.

Additionally, to paint a more granular picture of how additional factors affect mean squared error, I incrementally fit “greedy” flash objects with 1 to 60 factors using the full drop-seq dataset (not just a subsample of basal cells).

Code

The code used to run the tests is reproduced below.

# library(ashr) -- uncomment after negative mixture proportions bug is fixed
devtools::load_all("~/Github/ashr")
# library(flashier) -- uncomment after pushing 0.1.1 updates to master
devtools::load_all("~/Github/flashier")
library(Matrix)
library(ggplot2)


# Load data ---------------------------------------------------------------

# Only keep basal cells from the drop-seq dataset in Montoro et al. The data
#   can be downloaded here:
#     https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE103354
trachea <- read.table("./data/GSE103354_Trachea_droplet_UMIcounts.txt")
trachea <- as.matrix(trachea)
trachea <- Matrix(trachea)

cell.names <- colnames(trachea)
cell.types <- as.factor(sapply(strsplit(cell.names, "_"), `[`, 3))
basal <- trachea[, cell.types == 'Basal']


# Set test parameters -----------------------------------------------------

# Number of times to subsample data and run fits.
ntrials <- 50

# Size of subsampled data matrices.
ncells <- 200
ngenes <- 500

# Minimum number of nonzero counts needed to include a gene.
min.cts <- 3

# Proportion of entries to impute.
prop.missing <- 0.01
nmissing <- ceiling(prop.missing * ncells * ngenes)

# Number of factors to add to each fit.
K <- 5

# Verbose flag for flashier output.
verbose <- FALSE


# Set up data frames ------------------------------------------------------

create.df <- function(names) {
  names <- paste0(names, ".")
  df.names <- outer(names, c("overall", "zero", "nonzero"), FUN = paste0)
  df <- data.frame(rep(list(numeric(0)), length(df.names)))
  names(df) <- t(df.names)
  return(df)
}

# Test 1: variance structure.
var.df <- create.df(c("constant", "genewise", "fixed", "noisyA", "noisyB"))

# Test 2: data transformation.
trans.df <- create.df(c("log1p", "anscombe", "arcsin", "raw"))

# Test 3: normalization method.
norm.df <- create.df(c("none", "fitmean", "scale"))

# Test 4: prior type.
prior.df <- create.df(c("normal", "nngeneA", "nngeneB", "nncellA", "nncellB"))

# Test 5: number of factors and backfit.
nfactors.df <- create.df(c("g1", "bf1", "g2", "bf2", "g3", "bf3"))


# Create functions for populating data frames -----------------------------

preds <- function(fl, idx) {
  return(flashier:::lowrank.expand(fl$fit$EF)[idx])
}

mse <- function(true, preds, idx) {
  return(mean((true[idx] - preds[idx])^2))
}

mse.by.lvl <- function(true, preds) {
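  # Split entries by whether the true value is zero. Use the function
  #   argument rather than the global variable true.vals.
  zero.idx <- (true == 0)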
  return(list(overall.mse = mse(true, preds, rep(TRUE, length(preds))),
              zero.mse = mse(true, preds, zero.idx),
              nz.mse = mse(true, preds, !zero.idx)))
}

all.mse <- function(true, preds.list) {
  return(unlist(lapply(preds.list, function(preds) mse.by.lvl(true, preds))))
}


# Run tests ---------------------------------------------------------------

for (i in 1:ntrials) {
  cat("TRIAL", i, "\n")

  set.seed(i)
  rand.cells <- sample(1:ncol(basal), ncells)
  samp <- basal[, rand.cells]

  rand.genes <- sample(which(rowSums(samp > 0) >= min.cts), ngenes)
  samp <- samp[rand.genes, ]

  missing.idx <- sample(length(samp), nmissing)
  true.vals <- log1p(samp[missing.idx])
  samp[missing.idx] <- NA

  # Test 1: variance structure.
  cat("  Running variance structure tests.\n")

  fl.var0 <- flashier(log1p(samp), var.type = 0,
                      prior.type = "normal.mix",
                      greedy.Kmax = K + 1, backfit = "none",
                      verbose.lvl = 2L * verbose)

  fl.var1 <- flashier(log1p(samp), var.type = 1,
                      prior.type = "normal.mix",
                      greedy.Kmax = K + 1, backfit = "none",
                      verbose.lvl = 2L * verbose)

  S <- sqrt(samp) / (samp + 1)
  nz.prop <- sum(samp > 0, na.rm = TRUE) / (length(samp) - length(missing.idx))
  S[S == 0] <- sqrt(nz.prop) / (nz.prop + 1)

  fl.fixS <- flashier(log1p(samp), S = S,
                      var.type = NULL,
                      prior.type = "normal.mix",
                      greedy.Kmax = K + 1, backfit = "none",
                      verbose.lvl = 2L * verbose)

  suppressMessages({
    fl.noisyA <- flashier(log1p(samp), S = S,
                          var.type = 0,
                          prior.type = "normal.mix",
                          greedy.Kmax = K + 1, backfit = "none",
                          verbose.lvl = 2L * verbose)
  })

  suppressMessages({
    fl.noisyB <- flashier(log1p(samp), S = sqrt(samp) / (samp + 1),
                          var.type = 0,
                          prior.type = "normal.mix",
                          greedy.Kmax = K + 1, backfit = "none",
                          verbose.lvl = 2L * verbose)
  })

  var.df[i, ] <- all.mse(true.vals,
                         list(preds(fl.var0, missing.idx),
                              preds(fl.var1, missing.idx),
                              preds(fl.fixS, missing.idx),
                              preds(fl.noisyA, missing.idx),
                              preds(fl.noisyB, missing.idx)))

  # Test 2: data transformation.
  cat("  Running data transformation tests.\n")

  fl.log1p <- fl.var0

  fl.ans <- flashier(sqrt(samp + 0.375), var.type = 0,
                     prior.type = "normal.mix",
                     greedy.Kmax = K + 1, backfit = "none",
                     verbose.lvl = 2L * verbose)
  ans.preds <- log1p(preds(fl.ans, missing.idx)^2 - 0.375)

  cell.sums <- colSums(samp, na.rm = TRUE)
  props <- samp / rep(cell.sums, each = nrow(samp))
  fl.arcsin <- flashier(asin(sqrt(props)), var.type = 0,
                        prior.type = "normal.mix",
                        greedy.Kmax = K, backfit = "none",
                        verbose.lvl = 2L * verbose)
  missing.cols <- col(samp)[missing.idx]
  arcsin.preds <- log1p(sin(preds(fl.arcsin, missing.idx))^2
                        * cell.sums[missing.cols])

  fl.raw <- flashier(props, var.type = 0,
                     prior.type = "normal.mix",
                     greedy.Kmax = K,
                     backfit.after = 2, final.backfit = FALSE,
                     verbose.lvl = 2L * verbose)
  raw.preds <- log1p(preds(fl.raw, missing.idx) * cell.sums[missing.cols])
  raw.preds[is.nan(raw.preds)] <- 0

  trans.df[i, ] <- all.mse(true.vals,
                           list(preds(fl.log1p, missing.idx),
                                ans.preds,
                                arcsin.preds,
                                raw.preds))

  # Test 3: scaling method.
  cat("  Running scaling tests.\n")

  fl.none <- fl.var0

  fl.ones <- flashier(log1p(samp), var.type = 0,
                      prior.type = "normal.mix",
                      fixed.factors = c(ones.factor(2), ones.factor(1)),
                      greedy.Kmax = K,
                      backfit.after = 2, final.backfit = FALSE,
                      verbose.lvl = 2L * verbose)

  scaled.samp <- samp * median(cell.sums) / rep(cell.sums, each = nrow(samp))
  fl.scale <- flashier(log1p(scaled.samp), var.type = 0,
                       prior.type = "normal.mix",
                       fixed.factors = ones.factor(2),
                       greedy.Kmax = K, backfit = "none",
                       verbose.lvl = 2L * verbose)
  scale.preds <- log1p((exp(preds(fl.scale, missing.idx)) - 1)
                       * cell.sums[missing.cols] / median(cell.sums))

  norm.df[i, ] <- all.mse(true.vals,
                          list(preds(fl.none, missing.idx),
                               preds(fl.ones, missing.idx),
                               scale.preds))

  # Test 4: prior type.
  cat("  Running prior type tests.\n")

  fl.normalmix <- fl.var0

  fl.nngenes <- flashier(log1p(samp), var.type = 0,
                         prior.type = c("nonnegative", "normal.mix"),
                         greedy.Kmax = K + 1, backfit = "none",
                         verbose.lvl = 2L * verbose)

  fl.nngenes.pm <- flashier(log1p(samp), var.type = 0,
                            prior.type = c("nonnegative", "normal.mix"),
                            ash.param = list(method = "fdr"),
                            greedy.Kmax = K + 1, backfit = "none",
                            verbose.lvl = 2L * verbose)

  fl.nncells <- flashier(log1p(samp), var.type = 0,
                         prior.type = c("normal.mix", "nonnegative"),
                         greedy.Kmax = K + 1, backfit = "none",
                         verbose.lvl = 2L * verbose)

  fl.nncells.pm <- flashier(log1p(samp), var.type = 0,
                            prior.type = c("normal.mix", "nonnegative"),
                            ash.param = list(method = "fdr"),
                            greedy.Kmax = K + 1, backfit = "none",
                            verbose.lvl = 2L * verbose)

  prior.df[i, ] <- all.mse(true.vals,
                           list(preds(fl.normalmix, missing.idx),
                                preds(fl.nngenes, missing.idx),
                                preds(fl.nngenes.pm, missing.idx),
                                preds(fl.nncells, missing.idx),
                                preds(fl.nncells.pm, missing.idx)))

  # Test 5: number of factors and backfit.
  cat("  Running backfitting tests.\n")

  fl.g <- fl.var0

  fl.b <- flashier(flash.init = fl.g, backfit = "only",
                   backfit.reltol = 10,
                   verbose.lvl = 3L * verbose)

  # Add K more factors, then K more.
  fl.g2 <- flashier(flash.init = fl.g,
                    greedy.Kmax = K, backfit = "none",
                    verbose.lvl = 2L * verbose)
  fl.b2 <- flashier(flash.init = fl.b,
                    greedy.Kmax = K, backfit = "final",
                    backfit.reltol = 10,
                    verbose.lvl = 3L * verbose)
  fl.g3 <- flashier(flash.init = fl.g2,
                    greedy.Kmax = K, backfit = "none",
                    verbose.lvl = 2L * verbose)
  fl.b3 <- flashier(flash.init = fl.b2,
                    greedy.Kmax = K, backfit = "final",
                    backfit.reltol = 10,
                    verbose.lvl = 3L * verbose)

  nfactors.df[i, ] <- all.mse(true.vals,
                              list(preds(fl.g, missing.idx),
                                   preds(fl.b, missing.idx),
                                   preds(fl.g2, missing.idx),
                                   preds(fl.b2, missing.idx),
                                   preds(fl.g3, missing.idx),
                                   preds(fl.b3, missing.idx)))
}

# Test 6: incremental addition of factors

set.seed(666)
data <- log1p(trachea[rowSums(trachea > 0) > 2, ])
missing.idx <- sample(1:length(data), ceiling(prop.missing * length(data)))
true.vals <- data[missing.idx]
data[missing.idx] <- NA

Kmax <- 60

fl <- flashier(data, var.type = 0,
               prior.type = "normal.mix",
               fixed.factors = c(ones.factor(1), ones.factor(2)),
               greedy.Kmax = 1, backfit = "none",
               final.nullchk = FALSE,
               verbose.lvl = 3)
mse.vec <- mse(true.vals, preds(fl, missing.idx), 1:length(true.vals))

for (k in 2:Kmax) {
  fl <- flashier(flash.init = fl,
                 greedy.Kmax = 1, backfit = "none",
                 final.nullchk = FALSE,
                 verbose.lvl = 3)
  mse.vec <- c(mse.vec,
               mse(true.vals, preds(fl, missing.idx), 1:length(true.vals)))
}

mse.diff <- c(NA, mse.vec[2:length(mse.vec)] - mse.vec[1:(length(mse.vec) - 1)])
mse.df <- data.frame(k = 1:Kmax, mse = mse.vec, mse.diff = mse.diff)


# Save results ------------------------------------------------------------

all.res <- list(var.df = var.df,
                trans.df = trans.df,
                norm.df = norm.df,
                prior.df = prior.df,
                nfactors.df = nfactors.df,
                mse.df = mse.df)
saveRDS(all.res, "./output/sc_comparisons/allres.rds")
allres <- readRDS("./output/sc_comparisons/allres.rds")

Introduction to plots

For each comparison, I choose a “reference” method and plot, for each of the other methods, its mean squared error minus that of the reference method in each of the 50 trials. A positive value indicates that the method performs worse than the reference. Importantly, mean squared error is calculated on the log1p scale in all cases.

Note that the violin plots all have the same y-axis limits (from -0.025 to +0.1). This is to underscore the relative importance of each aspect. For example, the scrunched-up appearance of the scaling methods plot is intentional: as it turns out, the scaling method used does not affect the mean squared error nearly as much as, say, the variance structure. In several cases, there are a few outlying fits with very large MSEs. When this occurs, I display all trials in addition to displaying a “zoom” view with y-axis limits from -0.025 to +0.1.

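The red horizontal lines indicate the median and the 10% and 90% quantiles. Unfortunately, because the y-axis limits are set with ylim(), ggplot2 discards any data points outside those limits before calculating these quantities, so outlying fits do not contribute to them.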
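The caveat about the quantiles can be illustrated without any plotting. The sketch below uses made-up MSE differences (not results from this analysis) and mimics what happens when values outside the zoomed limits are dropped before the quantiles are computed. In ggplot2, `coord_cartesian(ylim = ...)` zooms without dropping data, which would be one way to avoid the issue.

```r
# Hypothetical MSE differences: mostly small, with a few very large outliers.
set.seed(1)
mse.diffs <- c(rnorm(45, mean = 0.01, sd = 0.01), rep(5, 5))

# Quantiles over all trials.
q.all <- quantile(mse.diffs, c(0.1, 0.5, 0.9))

# Quantiles after discarding values outside the zoomed limits, which is
#   what happens when the limits are set via ylim().
in.range <- mse.diffs >= -0.025 & mse.diffs <= 0.1
q.zoom <- quantile(mse.diffs[in.range], c(0.1, 0.5, 0.9))

print(rbind(all = q.all, zoomed = q.zoom))
```

In this toy example the 90% quantile over all trials is pulled far above the zoomed range by the outliers, whereas the filtered version stays inside it.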

Variance structures

The most commonly used transformations of single-cell data are log transforms with pseudocounts. Throughout these tests, my default procedure is to transform the raw counts using the log1p transform: \[ Y_{ij} = \log \left( X_{ij} + 1 \right) \]

Thus, the model that flashier fits is \[ \log \left( X_{ij} + 1 \right) = \sum_k L_{ik} F_{jk} + E_{ij}, \] with \[E_{ij} \sim N(0, \sigma_{ij}^2)\]

If one disregards sampling error, one can simply fit a constant variance structure \(\sigma_{ij}^2 = \sigma^2\) or a gene-wise variance structure \(\sigma_{ij}^2 = \sigma_j^2\), with \(\sigma^2\) (respectively, \(\sigma_j^2\)) to be estimated.

However, if (as is reasonable to assume) the data is Poisson for some true expression levels \(\lambda_{ij}\): \[ X_{ij} \sim \text{Poisson} (\lambda_{ij}), \] then sampling errors will be much different for large and small counts. Expanding around the MLE \(\hat{\lambda}_{ij} = X_{ij}\): \[ \log (\lambda_{ij} + 1) \approx \log(X_{ij} + 1) + \frac{\lambda_{ij} - X_{ij}}{X_{ij} + 1} \] so that \[ \text{Var} \left( \log (\lambda_{ij} + 1) \right) \approx \frac{\lambda_{ij}}{(X_{ij} + 1)^2} \approx \frac{X_{ij}}{(X_{ij} + 1)^2}\] Ignoring approximation error, one can thus fix the standard errors at \[ S_{ij} = \frac{\sqrt{X_{ij}}}{X_{ij} + 1} \] Unfortunately, this sets \(S_{ij} = 0\) when \(X_{ij} = 0\). To circumvent this problem, I replace zeros with a very rough estimate of the baseline Poisson noise: \[ \lambda_0 := \frac{\#\{X_{ij}: X_{ij} \ne 0\}}{\#\{X_{ij}\}} \] That is, when \(X_{ij} = 0\), I set \[ S_{ij} = \frac{\sqrt{\lambda_0}}{\lambda_0 + 1} \]
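As a rough sanity check on the approximation above, the following sketch simulates Poisson counts at a few arbitrarily chosen rates and compares the empirical standard deviation of log1p(X) with \(\sqrt{\lambda} / (\lambda + 1)\). Note that this checks the sampling variability of log1p(X) at a known rate, which is the counterpart of (not identical to) the expansion around the MLE used above. The second part applies the zero-replacement rule to a toy matrix; all names and values here are illustrative assumptions.

```r
# Compare SD(log1p(X)) with the approximation sqrt(lambda) / (lambda + 1).
set.seed(1)
lambdas <- c(5, 50, 500)
nsim <- 1e5

empirical.sd <- sapply(lambdas, function(lambda) sd(log1p(rpois(nsim, lambda))))
approx.sd <- sqrt(lambdas) / (lambdas + 1)

se.check <- data.frame(lambda = lambdas,
                       empirical = empirical.sd,
                       approx = approx.sd,
                       ratio = empirical.sd / approx.sd)
print(se.check)

# Replace zero standard errors as described above, using a toy count matrix
#   that is mostly zeros.
X <- matrix(rpois(200, 0.2), nrow = 20)
S <- sqrt(X) / (X + 1)
lambda0 <- mean(X != 0)  # proportion of nonzero counts
S[S == 0] <- sqrt(lambda0) / (lambda0 + 1)
```

The approximation should be expected to be rougher for small rates, which is the regime that matters most for sparse single-cell data.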

Using flashier, one can also fit a model that includes both fixed sampling errors and an approximation error to be estimated, so that \[E_{ij} \sim N(0, S_{ij}^2 + \sigma^2)\] I try two approaches. One fixes the standard errors as described above, replacing \(S_{ij} = 0\) with \(S_{ij} = \sqrt{\lambda_0} / (\lambda_0 + 1)\). The other leaves standard errors of zero as they are and hopes that the estimated \(\sigma^2\) will be able to make up the difference. In both cases, the estimated \(\sigma^2\) is constant across genes and cells (estimating a noisy gene-wise variance structure is at present far too slow for large datasets).

library(ggplot2)

get.plot.df <- function(df, ref.lvl) {
  tests <- colnames(df)
  category <- sapply(strsplit(tests, '[.]'), `[`, 1)
  exp.lvl <- sapply(strsplit(tests, '[.]'), `[`, 2)
  df <- df[, which(exp.lvl == "overall")]
  category <- category[which(exp.lvl == "overall")]
  which.ref <- which(category == ref.lvl)
  df <- df - df[, which.ref]
  df <- df[, -which.ref]
  category <- category[-which.ref]
  
  return(data.frame(value = as.vector(as.matrix(df)), 
                    category = rep(category, each = nrow(df))))
}

var.df <- get.plot.df(allres$var.df, ref.lvl = "constant")

ggplot(var.df, aes(x = category, y = value)) +
  geom_jitter(position = position_jitter(0.25)) +
  labs(x = NULL, y = "Difference in MSE",
       title = "Variance structures (reference: Constant)") +
  scale_x_discrete(labels = c("Fixed", "Gene-wise", "Noisy A", "Noisy B"))

ggplot(var.df, aes(x = category, y = value)) +
  geom_violin(adjust = 1.5, 
              draw_quantiles = c(0.1, 0.5, 0.9), 
              color = "red") +
  geom_jitter(position = position_jitter(0.25)) +
  labs(x = NULL, y = "Difference in MSE",
       title = "Variance structures (zoom)") +
  scale_x_discrete(labels = c("Fixed", "Gene-wise", "Noisy (no zero SEs)", "Noisy (zero SEs)")) +
  ylim(-0.025, 0.1)

Discussion

The constant variance structure clearly works best. A gene-wise variance structure can do as well (and occasionally better), but also results in a lot of very poor fits. I’ve found that a gene-wise variance structure often estimates the variances for genes with few non-zero counts to be very small, and that this can lead to severe overfitting.

Results for fixed standard errors are nearly identical to results for the noisy variance structure that uses the same standard errors, and the noisy variance structure that allows standard errors of zero does much worse. I’ve found that in both cases, \(\sigma^2\) is almost always estimated to be close to zero. In the latter case, this leads to an estimated error variance that is far too low for most counts (as mentioned, approximately 90% of counts are equal to zero).

Data transformations

Another way to handle unequal sampling errors is to use a variance-stabilizing transformation. I compare the log1p transform to the Anscombe transform (which stabilizes variance for Poisson data): \[ Y_{ij} = \sqrt{X_{ij} + \frac{3}{8}},\] and the arcsine transform (which stabilizes variance for proportions): \[ Y_{ij} = \text{arcsin}\left(\sqrt{\frac{X_{ij}}{\sum_j X_{ij}}}\right)\] In both cases, the magnitude of the sampling errors should be similar across all entries, so it should suffice to fit a constant variance structure.
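To illustrate what variance stabilization means here, the sketch below simulates Poisson counts at a few arbitrarily chosen rates and compares the variance of the log1p-transformed counts with the variance of the Anscombe-transformed counts. For the form \(\sqrt{X + 3/8}\) used above, the stabilized variance should be close to 1/4 for moderately large rates. The rates and simulation size are assumptions made for illustration only.

```r
set.seed(1)
lambdas <- c(10, 50, 200)
nsim <- 1e5

# One row per rate; columns give the variance under each transform.
var.by.transform <- t(sapply(lambdas, function(lambda) {
  x <- rpois(nsim, lambda)
  c(lambda = lambda,
    log1p = var(log1p(x)),
    anscombe = var(sqrt(x + 0.375)))
}))
print(var.by.transform)
```

The log1p variances shrink as the rate grows, whereas the Anscombe variances stay roughly constant, which is why a constant variance structure is a natural choice for the latter.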

Note that almost all gene proportions are small, and that the arcsine function is approximately linear for small \(x\). Thus the arcsine transformation is not much different from a square-root transformation of the proportions, which in turn might not be much different from the untransformed proportions. For purposes of comparison, then, I also fit a flash object to the untransformed proportions: \[ Y_{ij} = \frac{X_{ij}}{\sum_j X_{ij}} \]

In all cases, I calculate mean squared error on the log1p scale. While it is true that this might bias the results in favor of the log1p transform, the choice can be justified by noting that an error on the log1p scale approximates the relative error in the fitted counts (with a pseudocount added to avoid division by zero).
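The relative-error reading can be made explicit with a first-order expansion. Writing \(\hat{X}_{ij}\) for a fitted count (notation introduced here for illustration, not taken from the analysis code), the error on the log1p scale is \[ \log \left( \hat{X}_{ij} + 1 \right) - \log \left( X_{ij} + 1 \right) = \log \left( 1 + \frac{\hat{X}_{ij} - X_{ij}}{X_{ij} + 1} \right) \approx \frac{\hat{X}_{ij} - X_{ij}}{X_{ij} + 1}, \] provided that the fitted count is reasonably close to the observed count. The pseudocount in the denominator is what keeps the ratio finite when \(X_{ij} = 0\).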

For the log1p and Anscombe transforms, I fit an additional rank-one “mean” factor (which simply amounts to fitting 6 factors instead of 5). Since the other transforms fit scaled data, no additional mean factor needs to be fitted.

trans.df <- get.plot.df(allres$trans.df, ref.lvl = "log1p")

ggplot(trans.df, aes(x = category, y = value)) +
  geom_violin(adjust = 1.5, 
              draw_quantiles = c(0.1, 0.5, 0.9), 
              color = "red") +
  geom_jitter(position = position_jitter(0.25)) +
  labs(x = NULL, y = "Difference in MSE",
       title = "Data transformations (reference: log1p)") +
  scale_x_discrete(labels = c("Anscombe", 
                              "Arcsine", 
                              "Proportions (untransformed)")) +
  ylim(-0.025, 0.1)

Discussion

The log1p and Anscombe transforms imply very different assumptions about the model: with the log1p transform, factors are (roughly) multiplicative, whereas they are closer to additive under the Anscombe transform. Nonetheless, the two perform surprisingly similarly in terms of mean squared error.

The arcsine transform does worse than either of the two, and (unsurprisingly) fitting the untransformed matrix of proportions does poorly (but maybe not as poorly as expected).

Given these results, I prefer the log1p transform since it allows for a simple interpretation of factors as multiplicative effects. Further, it is possible to fit scaling factors so that loadings are comparable with one another (as discussed in the following section). Since factors are not multiplicative under the Anscombe transform, one cannot directly fit scaling factors.

Scaling methods

Ideally, one would like factor loadings to be comparable across genes even though mean expression can vary by several orders of magnitude. Mean expression varies much less across cells, but one should also account for differences in cell size. One way to achieve both of these goals is to add fixed mean factors to perform row- and column-specific scaling. Specifically, I fit a fixed row vector of all ones with column loadings \(c_j\) to be estimated and a fixed column vector of all ones with row loadings \(r_i\) to be estimated. This is approximately equivalent to estimating separate scaling factors for the rows and columns of the count data: \[ X_{ij} + 1 = e^{r_i}e^{c_j} \]

Compare to the case where FLASH estimates a single rank-one factor with row loadings \(r_i\) and column loadings \(c_j\): \[ X_{ij} + 1 = e^{r_i c_j} \] Here, the scaling factors are not independent. Do note, however, that when \(r_i\) and \(c_j\) are both small, \[ X_{ij} = e^{r_i c_j} - 1 \approx r_i c_j, \] so fitting a single factor might actually work well for small counts.
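The difference between the two approaches can be seen in a small linear-algebra sketch. On the log1p scale, independent row and column scaling corresponds to an additive matrix with entries \(r_i + c_j\), which has rank two and is represented exactly by two fixed ones vectors with free loadings. A single rank-one factor can only approximate it. The sketch uses a plain SVD as a stand-in for a single unconstrained factor (it is not a flash fit), and all values are simulated.

```r
set.seed(1)
n <- 8
p <- 6
r <- rnorm(n)      # row effects
c.eff <- rnorm(p)  # column effects

# Additive structure on the log1p scale: M[i, j] = r[i] + c.eff[j].
M <- outer(r, rep(1, p)) + outer(rep(1, n), c.eff)
M.rank <- qr(M)$rank

# Best rank-one approximation via SVD.
sv <- svd(M)
M1 <- sv$d[1] * outer(sv$u[, 1], sv$v[, 1])
rank1.resid <- max(abs(M - M1))

cat("Rank of additive matrix:", M.rank, "\n")
cat("Max abs. residual of rank-one fit:", rank1.resid, "\n")
```

The nonzero residual reflects the part of the row and column scaling that a single rank-one factor cannot capture.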

A third possible method scales the cells in advance (that is, before the log1p transform). Letting \(R_i\) be the total count for cell \(i\), I scale each cell by the factor \[ \frac{\text{median}(R_i)}{R_i} \] so that, in particular, each scaled cell has the same total count. (A mean factor for genes still needs to be fit so that gene loadings are comparable.) Although this method is conceptually simpler, it risks over-relying on a few genes with large counts. In principle, a bi-scaling method like the two discussed above should be more accurate.
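A minimal sketch of the pre-scaling step on a toy count matrix follows. It assumes genes in rows and cells in columns (the layout used in the analysis code), and the matrix itself is simulated for illustration. It also shows how values on the scaled log1p scale can be mapped back to the unscaled log1p scale, which is needed before comparing against held-out values.

```r
set.seed(1)
# Toy count matrix; the dimensions and rate are arbitrary.
X <- matrix(rpois(50 * 6, lambda = 3), nrow = 50)

cell.sums <- colSums(X)
scale.factors <- median(cell.sums) / cell.sums

# Scale each cell (column) so that all cells have the same total count.
scaled.X <- sweep(X, 2, scale.factors, `*`)
print(colSums(scaled.X))

# Map values on the scaled log1p scale back to the unscaled log1p scale
#   by undoing the scaling on the count scale.
Y.scaled <- log1p(scaled.X)
Y.back <- log1p(sweep(expm1(Y.scaled), 2, scale.factors, `/`))
```

Using `expm1` rather than `exp(x) - 1` is a small numerical nicety; the two are mathematically equivalent.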

norm.df <- get.plot.df(allres$norm.df, ref.lvl = "none")

ggplot(norm.df, aes(x = category, y = value)) +
  geom_violin(adjust = 1.5, 
              draw_quantiles = c(0.1, 0.5, 0.9), 
              color = "red") +
  geom_jitter(position = position_jitter(0.25)) +
  labs(x = NULL, y = "Difference in MSE",
       title = "Scaling methods (reference: Rank-1 mean factor)") +
  scale_x_discrete(labels = c("Fixed ones vectors", 
                              "Pre-scaled cells")) +
  ylim(-0.025, 0.1)

Discussion

All methods do about equally well. Since it is difficult to understand exactly what the rank-one mean factor is fitting, I prefer using fixed ones vectors or pre-scaling cells when doing an in-depth analysis of a particular dataset. But the rank-one approach is easier to implement (and a bit faster), so I adopt it as the default scaling method in all other tests throughout this analysis.

Priors

For simplicity, I use normal-mixture priors as the default priors throughout this analysis, but using nonnegative priors for either genes or cells can enhance interpretability. The former yields sets of genes that co-express in the same direction, with cell loadings indicating whether expression levels for the gene set are above or below the mean. The latter can work better for clustering cells, since it yields sets of cells with one set of genes that is overexpressed and a second set that is underexpressed.

In both cases, I test two ashr parameter settings. method = "fdr" includes a point mass at zero in the prior, whereas method = "shrink" includes small mixture components but no point mass. In general, fdr is better for false discovery rate control while shrink tends to be slightly more accurate.

Since the normal-mixture prior is in principle the most flexible of all of these priors, I don’t expect the other priors to improve on the mean squared error. The primary reason for choosing to put a nonnegative prior along one of the two dimensions is interpretability, not accuracy.

prior.df <- get.plot.df(allres$prior.df, ref.lvl = "normal")

ggplot(prior.df, aes(x = category, y = value)) +
  geom_jitter(position = position_jitter(0.25)) +
  labs(x = NULL, y = "Difference in MSE",
       title = "Priors (reference: Normal mixture)") +
  scale_x_discrete(labels = c("NN cells (shrink)",
                              "NN cells (fdr)",
                              "NN genes (shrink)",
                              "NN genes (fdr)"))

ggplot(prior.df, aes(x = category, y = value)) +
  geom_violin(adjust = 1.5, 
              draw_quantiles = c(0.1, 0.5, 0.9), 
              color = "red") +
  geom_jitter(position = position_jitter(0.25)) +
  labs(x = NULL, y = "Difference in MSE",
       title = "Priors (zoom)") +
  scale_x_discrete(labels = c("NN cells (shrink)",
                              "NN cells (fdr)",
                              "NN genes (shrink)",
                              "NN genes (fdr)")) +
  ylim(-0.025, 0.1)

Discussion

In each case, there is at least one outlying fit that has very poor mean squared error. This problem is much worse when the nonnegative prior is on genes rather than cells. There doesn’t seem to be any consistent difference between method = "fdr" and method = "shrink".

If the outlying fits are ignored, each type of prior does approximately as well as normal-mixture priors. This is fairly surprising since 5 normal-mixture factors are in some sense equivalent to 10 semi-nonnegative factors. I’m not sure whether most of the useful information in a given factor is in one direction, or whether 5 normal-mixture factors are already beginning to overfit the data (as discussed in the following section).
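The sense in which one unconstrained factor corresponds to two semi-nonnegative factors can be checked directly: splitting a loadings vector into its positive and negative parts yields two rank-one terms with nonnegative loadings whose sum reproduces the original term. The vectors below are simulated for illustration and are not taken from any fit.

```r
set.seed(1)
l <- rnorm(8)  # loadings with mixed signs
f <- rnorm(6)  # factor with mixed signs

l.pos <- pmax(l, 0)   # nonnegative part of the loadings
l.neg <- pmax(-l, 0)  # magnitude of the negative part

# One unconstrained rank-one term...
rank.one <- outer(l, f)
# ...equals the sum of two terms whose loadings are nonnegative.
two.semi.nn <- outer(l.pos, f) + outer(l.neg, -f)

print(all.equal(rank.one, two.semi.nn))
```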

Number of factors and backfitting

In all previous tests, I fit 5 factors to each flash object. Here, I also fit 10 and 15 factors to see whether there’s any evidence of overfitting. Additionally, I look at how backfitting changes mean squared error.

nfactors.df <- get.plot.df(allres$nfactors.df, ref.lvl = "g1")

ggplot(nfactors.df, aes(x = category, y = value)) +
  geom_jitter(position = position_jitter(0.25)) +
  labs(x = NULL, y = "Difference in MSE",
       title = "Number of factors (reference: Greedy, 5 factors)") +
  scale_x_discrete(labels = c("Backfit, 5 f.", 
                              "Backfit, 10 f.", 
                              "Backfit, 15 f.",
                              "Greedy, 10 f.", 
                              "Greedy, 15 f."))

ggplot(nfactors.df, aes(x = category, y = value)) +
  geom_violin(adjust = 1.5, 
              draw_quantiles = c(0.1, 0.5, 0.9), 
              color = "red") +
  geom_jitter(position = position_jitter(0.25)) +
  labs(x = NULL, y = "Difference in MSE",
       title = "Number of factors (zoom)") +
  scale_x_discrete(labels = c("Backfit, 5 f.", 
                              "Backfit, 10 f.", 
                              "Backfit, 15 f.",
                              "Greedy, 10 f.", 
                              "Greedy, 15 f.")) +
  ylim(-0.025, 0.1)

Discussion

The results suggest that overfitting (in the sense of predictive accuracy) occurs well before the log likelihood of the model stops improving. The ten- and fifteen-factor fits are nearly always worse than the five-factor fits, and the backfits are on average worse than the corresponding greedy fits. In surprisingly many cases, backfitting turns out to be disastrous.

To further analyze the dynamics of overfitting, I fit a single factor to the full drop-seq dataset (after removing 1% of entries at random) and calculate mean squared error, then a second factor, and so on until 60 factors have been added. Even though flashier continues to add factors throughout the process, the mean squared error no longer improves monotonically after 32 factors, and bottoms out at 40 factors.
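Quantities such as the first non-monotone step and the location of the minimum can be read off a data frame with columns `k` and `mse` along the following lines. The values below are made up for illustration (they are not results from this analysis), and the names `toy.mse.df`, `first.increase`, and `best.k` are hypothetical.

```r
# Made-up MSE values for illustration only.
toy.mse.df <- data.frame(k = 1:8,
                         mse = c(0.50, 0.40, 0.35, 0.33, 0.34, 0.32, 0.36, 0.37))

# First k at which the MSE increases relative to k - 1.
first.increase <- toy.mse.df$k[which(diff(toy.mse.df$mse) > 0)[1] + 1]

# Number of factors at which the MSE is smallest.
best.k <- toy.mse.df$k[which.min(toy.mse.df$mse)]

cat("First increase at k =", first.increase, "\n")
cat("Minimum MSE at k =", best.k, "\n")
```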

The results imply that flashier is unable to add the “correct” number of factors for count data, and that some type of cross validation might be needed.

mse.df <- allres$mse.df

ggplot(mse.df, aes(x = k, y = mse)) + geom_point() +
  labs(x = "Number of factors", y = "Mean squared error",
       title = "Incremental greedy additions (full dataset)")

ggplot(subset(mse.df, k > 19), aes(x = k, y = mse)) + geom_point() +
  labs(x = "Number of factors", y = "Mean squared error",
       title = "Incremental greedy additions (zoom)")

Session information

sessionInfo()
R version 3.4.3 (2017-11-30)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS High Sierra 10.13.6

Matrix products: default
BLAS: /Library/Frameworks/R.framework/Versions/3.4/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.4/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] ggplot2_3.1.0

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.0        compiler_3.4.3    pillar_1.2.1     
 [4] git2r_0.21.0      plyr_1.8.4        workflowr_1.0.1  
 [7] bindr_0.1         R.methodsS3_1.7.1 R.utils_2.6.0    
[10] tools_3.4.3       digest_0.6.18     evaluate_0.12    
[13] tibble_1.4.2      gtable_0.2.0      pkgconfig_2.0.1  
[16] rlang_0.3.0.1     yaml_2.2.0        xfun_0.4         
[19] bindrcpp_0.2      withr_2.1.2.9000  stringr_1.3.1    
[22] dplyr_0.7.4       knitr_1.21.6      rprojroot_1.3-2  
[25] grid_3.4.3        glue_1.3.0        R6_2.3.0         
[28] rmarkdown_1.11    magrittr_1.5      whisker_0.3-2    
[31] backports_1.1.2   scales_1.0.0      htmltools_0.3.6  
[34] assertthat_0.2.0  colorspace_1.3-2  labeling_0.3     
[37] stringi_1.2.4     lazyeval_0.2.1    munsell_0.5.0    
[40] R.oo_1.21.0      

This reproducible R Markdown analysis was created with workflowr 1.0.1