Last updated: 2018-06-19

Fitting methods

The MASH fit is produced following the recommendations in the MASH vignettes (using both canonical matrices and data-driven matrices).

Three FLASH fits are produced.

FLASH-OHL (for “one-hots last”) adds up to ten factors greedily, then adds a one-hot vector for each row in the data matrix, then backfits the whole thing.

FLASH-OHF (for “one-hots first”) adds the one-hot vectors first (as you’ve probably already guessed), then backfits, then greedily adds up to ten factors. In this case the greedily added factors are not subsequently backfit, so FLASH-OHF can be much faster than FLASH-OHL.

Finally, FLASH-OHF+ is identical to FLASH-OHF except that it performs one more backfit of all factors at the end.


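All simulated datasets \(Y\) are of dimension \(25 \times 1000\) (25 conditions by 1000 genes). In each case, \(Y = X + E\), where \(X\) is the matrix of “true” effects and \(E\) is a matrix of i.i.d. \(N(0, 1)\) noise.

In each results table below, the four columns report, in order, MASH, FLASH-OHL, FLASH-OHF, and FLASH-OHF+ (FLASH-OHF followed by a final backfit of all factors).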

Null model

Here the entries of \(X\) are all zero.

MSE 0.003 0.003 0.003 0.003
95% CI cov 0.880 1.000 1.000 1.000

Model with independent effects

Now 80% of the columns \(X_{:, j}\) (chosen at random) are identically zero, and the remaining columns are nonzero. The entries of each nonzero column of \(X\) are i.i.d. \(N(0, 1)\).

MSE 0.014 0.018 0.018 0.018
95% CI cov 0.989 0.984 0.984 0.984

Model with independent and shared effects

Again 80% of the columns of \(X\) are identically zero. But now, only half of the nonzero columns have entries that are i.i.d. \(N(0, 1)\). The other half have entries that are identical across rows, with a value that is drawn from the \(N(0, 1)\) distribution. (In other words, the covariance matrix for these columns is a matrix of all ones.)

MSE 0.016 0.019 0.021 0.020
95% CI cov 0.991 0.985 0.984 0.984

Model with independent, shared, and unique effects

This model is similar to the previous two, but now only a third of the nonzero columns have independently distributed entries and a third have shared entries. Each column in the remaining third has a single nonzero entry. (This corresponds, for example, to a gene that is only expressed in a single condition.) The unique effects are distributed uniformly across rows.

MSE 0.016 0.020 0.021 0.021
95% CI cov 0.988 0.983 0.983 0.983

Rank 1 FLASH model

This is the FLASH model \(X = LF\), where \(L\) is an \(n\) by \(k\) matrix and \(F\) is a \(k\) by \(p\) matrix. In this first simulation, \(k = 1\). 80% of the entries in \(F\) and 50% of the entries in \(L\) are equal to zero. The other entries are i.i.d. \(N(0, 1)\).

MSE 0.022 0.016 0.037 0.020
95% CI cov 0.979 0.955 0.962 0.955

Rank 5 FLASH model

This is the same as above with \(k = 5\) and with only 20% of the entries in \(L\) equal to zero.

MSE 0.115 0.077 0.268 0.181
95% CI cov 0.928 0.901 0.909 0.875

Rank 3 FLASH model with UV

This is similar to the above with \(k = 3\) and with 30% of the entries in \(L\) equal to zero. In addition, a dense rank-one matrix \(W\) is added to the data \(Y\) (but not to the true effects \(X\)) to mimic the effects of unwanted variation. Here, \(W = UV\), with \(U\) an \(n\) by 1 vector and \(V\) a 1 by \(p\) vector, both of which have entries distributed \(N(0, 0.25)\) (i.e., with variance 0.25).

MSE 0.089 0.081 0.180 0.103
95% CI cov 0.946 0.905 0.935 0.914


Code for simulating datasets…

## SIMULATION FUNCTIONS -------------------------------------------------

# n is number of conditions, p is number of genes

# Noise matrix: n x p with i.i.d. N(0, sd^2) entries (N(0, 1) by default).
get_E <- function(n, p, sd = 1) {
  matrix(rnorm(n * p, mean = 0, sd = sd), nrow = n, ncol = p)
}

# Simulate from null model ----------------------------------------------

# Null data: X is all zeros, so Y is pure N(0, 1) noise.
# n: number of conditions (rows); p: number of genes (columns).
# seed: if non-NULL, passed to set.seed() before simulating.
# Returns list(Y = noise, true_Y = n x p zero matrix).
null_sim <- function(n, p, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }
  Y <- get_E(n, p)
  true_Y <- matrix(0, n, p)

  list(Y = Y, true_Y = true_Y)
}

# Simulate from MASH model ----------------------------------------------

# Sigma is a list of n x n covariance matrices.
# pi[k] is the (unnormalized) weight of Sigma[[k]]: each nonnull column of X
#   is drawn from N(0, Sigma[[k]]) with probability pi[k] / sum(pi).
#   Defaults to equal weights.
# s is sparsity: the proportion of columns of X that are identically zero
#   (exactly floor((1 - s) * p) columns, chosen at random, are nonnull).
# seed, if non-NULL, is passed to set.seed() before simulating.
# Returns list(Y = X + noise, true_Y = X), both n x p.
mash_sim <- function(n, p, Sigma, pi = NULL, s = 0.8, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }
  if (is.null(pi)) {
    pi <- rep(1, length(Sigma)) # default to uniform distribution
  }
  # stopifnot() raises on failure; assertthat::are_equal() merely returns
  # FALSE, so it cannot reject malformed input on its own.
  stopifnot(length(pi) == length(Sigma))
  for (k in seq_along(Sigma)) {
    # [[ ]] extracts the matrix itself; Sigma[k] is a one-element list whose
    # dim() is NULL.
    stopifnot(identical(dim(Sigma[[k]]), as.integer(c(n, n))))
  }

  pi <- pi / sum(pi) # normalize pi to sum to one
  which_sigma <- sample(seq_along(pi), p, replace = TRUE, prob = pi)
  nonnull_fx <- sample(seq_len(p), floor((1 - s) * p), replace = FALSE)

  X <- matrix(0, n, p)
  for (j in nonnull_fx) {
    X[, j] <- MASS::mvrnorm(1, rep(0, n), Sigma[[which_sigma[j]]])
  }
  Y <- X + get_E(n, p)
  list(Y = Y, true_Y = X)
}

# Simulate from FLASH model ---------------------------------------------

# Simulate from the FLASH model X = LL %*% FF (LL is n x k, FF is k x p).
# fs is sparsity of factors (probability that an entry of FF is zero)
# fvar is variance of the nonzero factor entries (normal)
# ls is sparsity of loadings (probability that an entry of LL is zero)
# lvar is variance of the nonzero loading entries (normal)
# UVvar is the variance of the entries of the two random vectors whose outer
#   product forms a dense rank-one matrix added to Y (not to true_Y) to mimic
#   something like unwanted variation (set it to 0 to ignore it)
# seed, if non-NULL, is passed to set.seed() before simulating.
# Returns list(Y = X + noise + UV, true_Y = X), both n x p.
flash_sim <- function(n, p, k, fs, fvar, ls, lvar, UVvar = 0, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }

  nonnull_ll <- matrix(sample(c(0, 1), n * k, replace = TRUE,
                              prob = c(ls, 1 - ls)), n, k)
  LL <- nonnull_ll * matrix(rnorm(n * k, 0, sqrt(lvar)), n, k)

  nonnull_ff <- matrix(sample(c(0, 1), k * p, replace = TRUE,
                              prob = c(fs, 1 - fs)), k, p)
  FF <- nonnull_ff * matrix(rnorm(k * p, 0, sqrt(fvar)), k, p)

  X <- LL %*% FF
  Y <- X + get_E(n, p)
  # add unwanted variation (to the observations only, not to true_Y)
  Y <- Y + outer(rnorm(n, 0, sqrt(UVvar)), rnorm(p, 0, sqrt(UVvar)))
  list(Y = Y, true_Y = X)
}

## SIMULATIONS ----------------------------------------------------------

# Functions to generate seven types of datasets. One is null; three are
# from the MASH model; three are from the FLASH model.

# Build the seven zero-argument simulation functions, in the order used by
# sim_names. Each returns list(Y = ..., true_Y = ...).
# n: number of conditions; p: number of genes; s: sparsity of FLASH factors;
# mashvar: variance of nonnull MASH effects; fvar, lvar: variances of FLASH
# factors and loadings; UVvar: variance of the unwanted-variation vectors.
sim_fns <- function(n, p, s, mashvar, fvar, lvar, UVvar) {
  # Each MASH simulation gets its own covariance list. R closures look up
  # free variables when they are called, not when they are defined, so
  # growing one shared Sigma list would make every MASH simulation draw
  # from the final (largest) list.
  Sigma_ind <- list(diag(mashvar, nrow = n))
  Sigma_indsh <- c(Sigma_ind, list(matrix(mashvar, n, n)))
  Sigma_unique <- lapply(seq_len(n), function(j) {
    Sj <- matrix(0, n, n)
    Sj[j, j] <- mashvar
    Sj
  })
  Sigma_mash <- c(Sigma_indsh, Sigma_unique)
  # Independent, shared and unique effects each get weight 1/3; the n unique
  # covariance matrices split the last third equally.
  pi_mash <- c(n, n, rep(1, n))

  # 1. Everything is null
  sim_null <- function() null_sim(n, p)

  # 2. Effects are independent across conditions
  sim_ind <- function() mash_sim(n, p, Sigma_ind)

  # 3. Effects are either independent or shared
  sim_indsh <- function() mash_sim(n, p, Sigma_indsh)

  # 4. Effects are independent, shared, or unique to a single condition
  sim_mash <- function() mash_sim(n, p, Sigma_mash, pi = pi_mash)

  # 5. Rank one model
  sim_rank1 <- function() flash_sim(n, p, 1, s, fvar, 0.5, lvar)

  # 6. Rank 5 model
  sim_rank5 <- function() flash_sim(n, p, 5, s, fvar, 0.2, lvar)

  # 7. Rank 3 model with unwanted variation
  sim_UV <- function() flash_sim(n, p, 3, s, fvar, 0.3, lvar, UVvar)

  list(sim_null, sim_ind, sim_indsh, sim_mash, sim_rank1, sim_rank5, sim_UV)
}

# Human-readable labels for the simulations, in the order returned by
# sim_fns() (also used as plot titles).
sim_names <- c(
  "Null simulation",
  "All independent effects",
  "Independent and shared",
  "Independent, shared, and unique",
  "Rank 1 FLASH model",
  "Rank 5 FLASH model",
  "Rank 3 FLASH with UV"
)

…for fitting MASH and FLASH objects…

# Fit using FLASH -------------------------------------------------------
# Fit FLASH with one fixed one-hot loading per row of Y plus up to Kmax
# greedily added factors (standard errors fixed at S = 1).
# method: "OHL"     - greedy factors first, then one-hots, then backfit all;
#         "OHF"     - one-hots first and backfit, then greedy factors
#                     (which are not backfit);
#         "OHFplus" - as "OHF", followed by a final backfit of everything.
# Returns list(fl = fitted flash object, timing = list of difftimes in
#   seconds with elements greedy, backfit and total).
fit_flash <- function(Y, Kmax, method) {
  # Fail loudly on an unknown method rather than silently running OHL.
  method <- match.arg(method, c("OHL", "OHF", "OHFplus"))
  n <- nrow(Y)
  data <- flash_set_data(Y, S = 1)
  one_hots <- diag(n) # column i is the one-hot loading for row i of Y
  timing <- list()
  # Use explicit seconds so timings can be summed/averaged/plotted together.
  secs_since <- function(start) difftime(Sys.time(), start, units = "secs")

  t0 <- Sys.time()
  if (method %in% c("OHF", "OHFplus")) {
    fl <- flash_add_fixed_l(data, one_hots)
    fl <- flash_backfit(data, fl, nullcheck = FALSE, var_type = "zero")
    t1 <- Sys.time()
    timing$backfit <- difftime(t1, t0, units = "secs")
    fl <- flash_add_greedy(data, Kmax, fl, var_type = "zero")
    timing$greedy <- secs_since(t1)
    if (method == "OHFplus") {
      t2 <- Sys.time()
      fl <- flash_backfit(data, fl, nullcheck = FALSE, var_type = "zero")
      timing$backfit <- timing$backfit + secs_since(t2)
    }
  } else {
    fl <- flash_add_greedy(data, Kmax, var_type = "zero")
    t1 <- Sys.time()
    timing$greedy <- difftime(t1, t0, units = "secs")
    fl <- flash_add_fixed_l(data, one_hots, fl)
    fl <- flash_backfit(data, fl, nullcheck = FALSE, var_type = "zero")
    timing$backfit <- secs_since(t1)
  }

  timing$total <- Reduce(`+`, timing)

  list(fl = fl, timing = timing)
}

# Fit using MASH -------------------------------------------------------
# Fit MASH to Y (conditions in rows, genes in columns).
# mash_set_data() is given t(Y), i.e. genes in rows and conditions in columns.
# ed: if TRUE, add data-driven covariance matrices (PCA with 5 components
#   refined by ED, computed on the strong signals found by a
#   condition-by-condition fit) to the canonical ones. The ED step can throw
#   an error when there are no strong signals; callers may retry with
#   ed = FALSE.
# Returns list(m = mash fit, timing = list of difftimes in seconds with
#   elements ed (0 when ed = FALSE), mash and total).
fit_mash <- function(Y, ed = TRUE) {
  data <- mash_set_data(t(Y))
  timing <- list()

  # time to create canonical matrices is negligible
  U <- cov_canonical(data)

  if (ed) {
    t0 <- Sys.time()
    m_1by1 <- mash_1by1(data)
    strong <- get_significant_results(m_1by1, 0.05)
    U_pca <- cov_pca(data, 5, strong)
    U_ed <- cov_ed(data, U_pca, strong)
    U <- c(U, U_ed)
    timing$ed <- difftime(Sys.time(), t0, units = "secs")
  } else {
    timing$ed <- as.difftime(0, units = "secs")
  }

  t0 <- Sys.time()
  m <- mash(data, U)
  timing$mash <- difftime(Sys.time(), t0, units = "secs")

  timing$total <- Reduce(`+`, timing)

  list(m = m, timing = timing)
}

…for evaluating performance…

# Evaluate methods based on MSE, CI coverage, and TPR vs. FPR -----------

# Evaluate a FLASH fit against the true effects.
# fl: fitted flash object; Y: observed data; true_Y: true effects (same
#   shape as Y); nsamp: number of posterior samples to draw.
# Returns list(MSE, CI, TP, FP, n_nulls, n_nonnulls); TP and FP are counts
#   at each lfsr threshold (see roc_data).
flash_diagnostics <- function(fl, Y, true_Y, nsamp) {
  MSE <- flash_mse(fl, true_Y)

  # Sample from FLASH fit to estimate CI coverage and TPR vs. FPR.
  # NOTE(review): fixed = "loadings" presumably keeps the (one-hot and
  # greedy) loadings fixed and samples only factors — confirm in flashr.
  fl_sampler <- flash_lf_sampler(Y, fl, ebnm_fn = ebnm_pn, fixed = "loadings")
  fl_samp <- fl_sampler(nsamp)

  CI <- flash_ci(fl_samp, true_Y)
  ROC <- flash_roc(fl, fl_samp, true_Y)

  list(MSE = MSE, CI = CI, TP = ROC$TP, FP = ROC$FP,
       n_nulls = ROC$n_nulls, n_nonnulls = ROC$n_nonnulls)
}

# Evaluate a MASH fit against the true effects (true_Y is conditions x genes;
# the helpers transpose it to match MASH's genes x conditions layout).
# Returns list(MSE, CI, TP, FP, n_nulls, n_nonnulls), matching
#   flash_diagnostics().
mash_diagnostics <- function(m, true_Y) {
  MSE <- mash_mse(m, true_Y)
  CI <- mash_ci(m, true_Y)
  ROC <- mash_roc(m, true_Y)

  list(MSE = MSE, CI = CI, TP = ROC$TP, FP = ROC$FP,
       n_nulls = ROC$n_nulls, n_nonnulls = ROC$n_nonnulls)
}

# MSE of posterior means (FLASH) ----------------------------------------
# Mean squared error between the FLASH posterior mean of LF and true_Y.
flash_mse <- function(fl, true_Y) {
  mean((flash_get_lf(fl) - true_Y)^2)
}

# MSE for MASH ----------------------------------------------------------
# Mean squared error of MASH posterior means. get_pm() is genes x conditions,
# so true_Y (conditions x genes) is transposed to match.
mash_mse <- function(m, true_Y) {
  mean((get_pm(m) - t(true_Y))^2)
}

# 95% CI coverage for FLASH ---------------------------------------------
# Empirical coverage of 95% credible intervals built from posterior samples.
# fl_samp: list of sampled matrices, each the same shape as true_Y.
# Returns the proportion of entries of true_Y lying inside (inclusive) the
#   2.5%-97.5% sample quantiles.
flash_ci <- function(fl_samp, true_Y) {
  n_entries <- length(true_Y)

  # One row per entry of true_Y (column-major order), one column per sample.
  # matrix() keeps the shape even when there is a single entry.
  flat_samp <- matrix(vapply(fl_samp, as.numeric, numeric(n_entries)),
                      nrow = n_entries)
  CI <- t(apply(flat_samp, 1, quantile, probs = c(0.025, 0.975)))

  truth <- as.vector(true_Y)
  mean(truth >= CI[, 1] & truth <= CI[, 2])
}

# 95% CI coverage for MASH ----------------------------------------------
# Empirical coverage of MASH 95% intervals pm +/- 1.96 * psd.
# true_Y is conditions x genes and is transposed to match get_pm().
# Bounds are inclusive, matching flash_ci(); with strict bounds a true value
# sitting on a zero-width interval (psd == 0) would count as not covered.
mash_ci <- function(m, true_Y) {
  Y <- t(true_Y)
  pm <- get_pm(m)
  half_width <- 1.96 * get_psd(m)
  mean((Y >= pm - half_width) & (Y <= pm + half_width))
}

# LFSR for FLASH --------------------------------------------------------
# Local false sign rate estimated from posterior samples: for each entry,
# 1 - (fraction of samples with the more common sign). Zero-valued samples
# count as neither positive nor negative.
# fl_samp: non-empty list of equally sized matrices.
# Returns a matrix of the same shape.
flash_lfsr <- function(fl_samp) {
  nsamp <- length(fl_samp)
  n_pos <- Reduce(`+`, lapply(fl_samp, function(s) s > 0), init = 0)
  n_neg <- Reduce(`+`, lapply(fl_samp, function(s) s < 0), init = 0)
  1 - pmax(n_pos, n_neg) / nsamp
}

# Quantities for plotting ROC curves -----------------------------------
# ROC quantities for FLASH: posterior means from the fit, lfsr estimated
# from the posterior samples. See roc_data() for the returned list.
flash_roc <- function(fl, fl_samp, true_Y, step = 0.01) {
  roc_data(flash_get_lf(fl), true_Y, flash_lfsr(fl_samp), step)
}

# ROC quantities for MASH. true_Y (conditions x genes) is transposed to match
# get_pm()/get_lfsr(). See roc_data() for the returned list.
mash_roc <- function(m, true_Y, step = 0.01) {
  roc_data(get_pm(m), t(true_Y), get_lfsr(m), step)
}

# Counts for an ROC curve over lfsr thresholds ts = 0, step, ..., 1.
# An entry is "significant" at threshold t when lfsr <= t. TP counts
# significant entries whose posterior mean has the same (nonzero) sign as the
# true effect; FP counts significant entries whose true effect is zero.
# pm, true_Y and lfsr must have the same shape.
# Returns list(ts, TP, FP, n_nulls, n_nonnulls).
roc_data <- function(pm, true_Y, lfsr, step) {
  correct_sign <- pm * true_Y > 0
  is_null <- true_Y == 0
  n_nulls <- sum(is_null)
  n_nonnulls <- length(true_Y) - n_nulls

  ts <- seq(0, 1, by = step)
  # Number of entries flagged at each threshold that also satisfy `hits`.
  count_at_thresholds <- function(hits) {
    vapply(ts, function(thresh) sum(lfsr <= thresh & hits), numeric(1))
  }

  list(ts = ts,
       TP = count_at_thresholds(correct_sign),
       FP = count_at_thresholds(is_null),
       n_nulls = n_nulls, n_nonnulls = n_nonnulls)
}

# empirical false sign rate vs. local false sign rate
# efsr_by_lfsr <- function(pm, true_Y, lfsr, step) {
#   pred_signs <- sign(pm)
#   pred_zeros <- pred_signs == 0
#   pred_signs[pred_zeros] <- sample(c(0, 1), length(pred_zeros), replace=T)
#   gotitright <- (pred_signs == sign(true_Y))
#   nsteps <- floor(.5 / step)
#   efsr_by_lfsr <- rep(0, nsteps)
#   for (k in 1:nsteps) {
#     idx <- (lfsr >= (step * (k - 1)) & lfsr < (step * k))
#     efsr_by_lfsr[k] <- ifelse(sum(idx) == 0, NA,
#                               1 - sum(gotitright[idx]) / sum(idx))
#   }
#   efsr_by_lfsr
# }

…and some ugly functions that run everything and plot results.

# Run nsims simulations from sim_fn (averaged when nsims > 1), then write:
#   <fpath>res.rds  - MSE / 95% CI coverage table (see output_res_mat)
#   <fpath>ROC.png  - ROC curves (skipped for the null simulation, whose
#                     true_Y has no nonnull entries, so TPR would be 0/0)
#   <fpath>time.png - timing bar plot
# fpath is used as a filename prefix via paste0().
# Returns the results invisibly.
run_sims <- function(sim_fn, nsims, plot_title, fpath) {
  if (nsims == 1) {
    res <- run_one_sim(sim_fn)
  } else {
    res <- run_many_sims(sim_fn, nsims)
  }
  saveRDS(output_res_mat(res, plot_title), paste0(fpath, "res.rds"))

  # Open a PNG device, draw into it, and always close it again (even if
  # drawing fails) so devices are not leaked across simulations.
  save_png <- function(file, draw) {
    grDevices::png(file)
    on.exit(grDevices::dev.off(), add = TRUE)
    draw()
  }

  if (plot_title != "Null simulation") {
    save_png(paste0(fpath, "ROC.png"), function() plot_ROC(res, plot_title))
  }
  save_png(paste0(fpath, "time.png"), function() plot_timing(res))

  invisible(res)
}

# Run nsims independent simulations and average their results element-wise.
# Every result must have the two-level structure returned by run_one_sim();
# each leaf (numbers, vectors, difftimes) is summed across runs and divided
# by nsims.
# Returns a list with the same structure as a single run_one_sim() result.
run_many_sims <- function(sim_fn, nsims) {
  stopifnot(nsims >= 1)
  res <- lapply(seq_len(nsims), function(i) run_one_sim(sim_fn))

  combined_res <- list()
  for (elem in names(res[[1]])) {
    combined_res[[elem]] <- list()
    for (sub_elem in names(res[[1]][[elem]])) {
      vals <- lapply(res, function(x) x[[elem]][[sub_elem]])
      combined_res[[elem]][[sub_elem]] <- Reduce(`+`, vals) / nsims
    }
  }
  combined_res
}

# Simulate one dataset with sim_fn(), fit MASH and the three FLASH variants
# (OHL, OHF, OHFplus), and compute diagnostics for each.
# Kmax: maximum number of greedy FLASH factors; nsamp: number of posterior
#   samples used for the FLASH diagnostics.
# Returns a named list of timings and diagnostics per method.
run_one_sim <- function(sim_fn, Kmax = 10, nsamp = 200) {
  data <- sim_fn()

  # If there are no strong signals, trying to run ED throws an error, so
  # fall back to a MASH fit that uses only the canonical matrices. The
  # result is checked directly rather than with exists(), which would also
  # find an unrelated `mfit` in the global environment.
  mfit <- tryCatch(fit_mash(data$Y), error = function(e) {
    message("MASH fit with ED failed (", conditionMessage(e),
            "); refitting without data-driven matrices")
    NULL
  })
  if (is.null(mfit)) {
    mfit <- fit_mash(data$Y, ed = FALSE)
    mfit$timing$ed <- as.difftime(0, units = "secs")
  }

  flfit1 <- fit_flash(data$Y, Kmax, method = "OHL")
  flfit2 <- fit_flash(data$Y, Kmax, method = "OHF")
  flfit3 <- fit_flash(data$Y, Kmax, method = "OHFplus")

  message("Running MASH diagnostics")
  mres <- mash_diagnostics(mfit$m, data$true_Y)
  message("Running FLASH diagnostics")
  flres1 <- flash_diagnostics(flfit1$fl, data$Y, data$true_Y, nsamp)
  flres2 <- flash_diagnostics(flfit2$fl, data$Y, data$true_Y, nsamp)
  flres3 <- flash_diagnostics(flfit3$fl, data$Y, data$true_Y, nsamp)

  list(mash_timing = mfit$timing, mash_res = mres,
       flash_OHL_timing = flfit1$timing, flash_OHL_res = flres1,
       flash_OHF_timing = flfit2$timing, flash_OHF_res = flres2,
       flash_OHFplus_timing = flfit3$timing, flash_OHFplus_res = flres3)
}

# Summary table: one column per method, rows "MSE" and "95% CI cov".
# res: result list as returned by run_one_sim()/run_many_sims().
# caption: not used when building the table.
#   NOTE(review): presumably meant as a table caption — confirm intent.
output_res_mat <- function(res, caption) {
  # MSE and CI coverage for one method's diagnostics list.
  summarize <- function(r) c(r$MSE, r$CI)
  data.frame(MASH = summarize(res$mash_res),
             FLASH_OHL = summarize(res$flash_OHL_res),
             FLASH_OHF = summarize(res$flash_OHF_res),
             FLASH_OHFplus = summarize(res$flash_OHFplus_res),
             row.names = c("MSE", "95% CI cov"))
}

# Stacked bar plot of average fitting time per method. The bottom segment is
# ED (MASH) or the greedy step (FLASH); the top segment is the MASH fit or
# the backfit.
plot_timing <- function(res) {
  timings <- list(
    res$mash_timing$ed, res$mash_timing$mash,
    res$flash_OHL_timing$greedy, res$flash_OHL_timing$backfit,
    res$flash_OHF_timing$greedy, res$flash_OHF_timing$backfit,
    res$flash_OHFplus_timing$greedy, res$flash_OHFplus_timing$backfit
  )
  # Convert each duration to seconds explicitly: c() only keeps (and
  # reconciles) difftime units from R 4.1.0 onwards.
  time_units <- "secs"
  secs <- vapply(timings, function(x) as.numeric(x, units = time_units),
                 numeric(1))
  # Rows are the two fitting stages, columns are the four methods.
  data <- matrix(secs, nrow = 2, ncol = 4)
  barplot(data, axes = TRUE,
          main = paste("Average time to fit in", time_units),
          names.arg = c("MASH", "FLASH-OHL", "FLASH-OHF", "FLASH-OHF+"),
          legend.text = c("ED/Greedy", "MASH/Backfit"),
          ylim = c(0, max(colSums(data)) * 2))
  # (increasing ylim is easiest way to deal with legend getting in way)
}

# Plot TPR vs. FPR over lfsr thresholds for MASH and the three FLASH fits.
# res: result list as returned by run_one_sim()/run_many_sims().
plot_ROC <- function(res, main = "ROC curve") {
  # FPR (x) and TPR (y) from one method's diagnostics.
  rates <- function(r) list(x = r$FP / r$n_nulls, y = r$TP / r$n_nonnulls)
  m <- rates(res$mash_res)
  ohl <- rates(res$flash_OHL_res)
  ohf <- rates(res$flash_OHF_res)
  ohfp <- rates(res$flash_OHFplus_res)

  plot(m$x, m$y, xlim = c(0, 1), ylim = c(0, 1), type = "l",
       xlab = "FPR", ylab = "TPR", main = main)
  lines(ohl$x, ohl$y, lty = 2)
  lines(ohf$x, ohf$y, lty = 3)
  lines(ohfp$x, ohfp$y, lty = 4)
  legend("bottomright", c("MASH", "FLASH-OHL", "FLASH-OHF", "FLASH-OHF+"),
         lty = 1:4)
}

