Last updated: 2018-11-09

workflowr checks: (Click a bullet for more information)
Expand here to see past versions:


This script implements the “Gaussian variance estimation” simulation experiments in the paper. In particular, we compare the Mean Field Variational Bayes (MFVB) method against SMASH in two scenarios.

The figure and table generated by this script should match the corresponding figure and table shown in the paper.

Running the code could take several hours to complete as it runs the two methods on 100 simulated data sets for each of the two scenarios.

We thank M. Menictas & M. Wand for generously sharing code that was used to implement these experiments.

Initial setup instructions

To run this example on your own computer, please follow these setup instructions. These instructions assume you already have R and/or RStudio installed on your computer.

First, download or clone the [git repository](https://github.com/stephenslab/smash-paper) on your computer.

Launch R, and change the working directory to be the “analysis” folder inside your local copy of the git repository.

Finally, install the smashr package from GitHub:

# Install smashr from GitHub without upgrading packages that are already
# installed.
# NOTE(review): newer devtools/remotes releases spell this option
# `upgrade = "never"`; confirm against the installed devtools version.
devtools::install_github("stephenslab/smashr",upgrade_dependencies = FALSE)

See the “Session Info” at the bottom for the versions of the software and R packages that were used to generate the results shown below.

Set up R environment

Load the smashr package, as well as some functions used in the analysis below.

library(smashr)
source("../code/mfvb.R")

Analysis settings

Specify the number of data sets simulated in the first and second simulation scenarios.

# Number of simulated data sets in the first (nsim1) and second (nsim2)
# scenarios. The experiments described in this document use 100 data sets
# per scenario; smaller values run faster, but the averaged MSEs will then
# not be comparable with the table in the paper.
nsim1 <- 100
nsim2 <- 100

Next, specify the hyperparameters used in running the MFVB method.

# Hyperparameters passed, in this order, to meanVarMFVB() in both
# simulation scenarios.
Au.hyp      <- 1e5
Av.hyp      <- 1e5
sigsq.gamma <- 1e10
sigsq.beta  <- 1e10

These variables specify some colours used in the plots.

# Colour names for plots.
# NOTE(review): none of these are referenced in the code shown in this
# document; presumably they are read by plotting code elsewhere -- confirm
# before renaming or removing them.
mainCol <- "darkslateblue"
ptCol   <- "paleturquoise3"
lineCol <- "skyblue"
axisCol <- "black"

These are additional plotting parameters.

# Character-expansion factors and the x-axis label for plots.
# NOTE(review): not referenced in the code shown in this document (the plot
# below sets xlab directly); presumably used elsewhere -- confirm.
cex.pt      <- 0.75
cex.mainVal <- 1.7
cex.labVal  <- 1.3
xlabVal     <- "x"

Plot mean and variance functions used to simulate data

Compare this plot against the one shown in Fig. 4 of the paper.

# Draw the mean function fTrue on a fine grid over [0,1], together with
# the band "mean +/- 2 standard deviations", where the standard deviation
# is the square root of the variance function gTrue.
xgrid   <- (0:10000)/10000
f.grid  <- fTrue(xgrid)
sd.grid <- sqrt(gTrue(xgrid))
plot(xgrid,f.grid,type = "l",ylim = c(-5,5),ylab = "y",xlab = "X",lwd = 2)
for (band.mult in c(2,-2)) {
  lines(xgrid,f.grid + band.mult*sd.grid,col = "darkorange",lwd = 2)
}

First simulation scenario: unevenly spaced data

In the first scenario, we simulated unevenly spaced data points, and assessed accuracy by computing the mean of the squared errors (MSE) evaluated at 201 equally spaced points.

# Per-data-set mean squared errors for the first scenario. Each vector is
# created here and gains one entry per simulated data set in the loop that
# follows. The SMASH vectors use the ".s8" suffix because those are the
# names the simulation loop assigns to and the summary table averages;
# without them the first assignment in the loop fails with "object not
# found".
mse.mu.uneven.mfvb <- 0
mse.mu.uneven.s8   <- 0
mse.sd.uneven.mfvb <- 0
mse.sd.uneven.s8   <- 0

Run the SMASH and MFVB methods for each simulated data set.

# Per-data-set mean squared errors for Scenario 1, one entry per simulated
# data set. They are allocated here so that every vector assigned inside
# the loop exists before its first use.
mse.mu.uneven.mfvb <- numeric(nsim1)
mse.sd.uneven.mfvb <- numeric(nsim1)
mse.mu.uneven.haar <- numeric(nsim1)
mse.sd.uneven.haar <- numeric(nsim1)
mse.mu.uneven.s8   <- numeric(nsim1)
mse.sd.uneven.s8   <- numeric(nsim1)

cat(sprintf("Running %d simulations: ",nsim1))
for (j in seq_len(nsim1)) {
  cat(sprintf("%d ",j))

  # SIMULATE DATA
  # NOTE(review): the seed is reset to the same value before the noise is
  # drawn, so rnorm() starts from the same generator state as runif(). This
  # is left unchanged so that the simulated data sets stay the same.
  set.seed(3*j)
  n     <- 500
  xOrig <- runif(n)
  set.seed(3*j)
  yOrig <- fTrue(xOrig) + sqrt(exp(loggTrue(xOrig)))*rnorm(n)

  # Standardize x and y before fitting the MFVB model; a and b are the
  # endpoints of the x range on the standardized scale.
  aOrig  <- min(xOrig)
  bOrig  <- max(xOrig)
  mean.x <- mean(xOrig)
  sd.x   <- sd(xOrig)
  mean.y <- mean(yOrig)
  sd.y   <- sd(yOrig)

  a <- (aOrig - mean.x)/sd.x
  b <- (bOrig - mean.x)/sd.x
  x <- (xOrig - mean.x)/sd.x
  y <- (yOrig - mean.y)/sd.y

  # Spline bases for the mean (U) and variance (V) functions, with interior
  # knots at equally spaced quantiles of x.
  # NOTE(review): the variable names in this section are kept exactly as
  # they were because meanVarMFVB() is not passed Zu, Zv, Cumat or Cvmat and
  # may look them up by name -- confirm in mfvb.R before renaming.
  numIntKnotsU <- 17
  intKnotsU <-
    quantile(x,seq(0,1,length.out = numIntKnotsU + 2)[-c(1,numIntKnotsU+2)])
  Zu        <- ZOSull(x,intKnots=intKnotsU,range.x=c(a,b))
  numKnotsU <- ncol(Zu)

  numIntKnotsV <- numIntKnotsU
  intKnotsV <-
    quantile(x,seq(0,1,length.out = numIntKnotsV + 2)[-c(1,numIntKnotsV+2)])
  Zv        <- ZOSull(x,intKnots=intKnotsV,range.x=c(a,b))
  numKnotsV <- ncol(Zv)

  # Run Mean Field Variational Bayes (MFVB)
  # ---------------------------------------
  X     <- cbind(rep(1,n),x)
  Cumat <- cbind(X,Zu)
  Cvmat <- cbind(X,Zv)
  ncX   <- ncol(X)
  ncZu  <- ncol(Zu)
  ncZv  <- ncol(Zv)
  ncCu  <- ncol(Cumat)
  ncCv  <- ncol(Cvmat)

  MFVBfit <- meanVarMFVB(y,X,ncZu,ncZv,Au.hyp,Av.hyp,
                         sigsq.gamma,sigsq.beta)

  # Design matrices on the evaluation grid: 201 equally spaced points
  # between the smallest and largest simulated x.
  ng     <- 201
  xgOrig <- seq(aOrig,bOrig,length.out = ng)
  xg     <- (xgOrig - mean.x)/sd.x
  Xg     <- cbind(rep(1,ng),xg)
  Zug    <- ZOSull(xg,intKnots=intKnotsU,range.x=c(a,b))
  Cug    <- cbind(Xg,Zug)
  Zvg    <- ZOSull(xg,intKnots=intKnotsV,range.x=c(a,b))
  Cvg    <- cbind(Xg,Zvg)

  mu.q.nu       <- MFVBfit$mu.q.nu
  mu.q.omega    <- MFVBfit$mu.q.omega
  Sigma.q.nu    <- MFVBfit$Sigma.q.nu
  Sigma.q.omega <- MFVBfit$Sigma.q.omega

  # Mean and log-variance estimates on the grid, mapped back to the
  # original scale of y.
  fhatMFVBg        <- Cug%*%mu.q.nu
  fhatMFVBgOrig    <- fhatMFVBg*sd.y + mean.y
  logghatMFVBg     <- Cvg%*%mu.q.omega
  logghatMFVBgOrig <- logghatMFVBg + 2*log(sd.y)

  # diag(Cvg Sigma Cvg') is needed twice below, so compute it once.
  var.logg <- diag(Cvg%*%Sigma.q.omega%*%t(Cvg))

  sdloggMFVBgOrig      <- sqrt(var.logg)
  credLowloggMFVBgOrig <- logghatMFVBgOrig - qnorm(0.975)*sdloggMFVBgOrig
  credUpploggMFVBgOrig <- logghatMFVBgOrig + qnorm(0.975)*sdloggMFVBgOrig

  # Standard deviation estimate on the standardized scale, then rescaled.
  sqrtghatMFVBg     <- exp(0.5*logghatMFVBg + 0.125*var.logg)
  sqrtghatMFVBgOrig <- sqrtghatMFVBg*sd.y

  mse.mu.uneven.mfvb[j] <- mean((fhatMFVBgOrig - fTrue(xgOrig))^2)
  mse.sd.uneven.mfvb[j] <- mean((sqrtghatMFVBgOrig-exp((loggTrue(xgOrig))/2))^2)

  # RUN SMASH
  # Sort the data by x, taking the median of y at any repeated x value.
  x.mod <- unique(sort(xOrig))
  y.mod <- vapply(x.mod,function (xi) median(yOrig[xOrig == xi]),numeric(1))
  n.mod <- length(x.mod)
  keep  <- seq_len(n.mod)

  # Extend the sorted data by reflecting its tail so that, for 500 distinct
  # x values, the length becomes 2^9 = 512; then append the mirror image of
  # the whole sequence.
  y.exp   <- c(y.mod,y.mod[n.mod:(2*n.mod-2^9+1)])
  y.final <- c(y.exp,y.exp[length(y.exp):1])

  # SMASH with filter.number = 1 (the same settings that Scenario 2 records
  # under the "haar" name). Only the entries corresponding to the observed
  # data are kept, then linearly interpolated onto the evaluation grid.
  mu.est  <- smash.gaus(y.final,filter.number=1,family="DaubExPhase")
  var.est <- smash.gaus(y.final,v.est=TRUE)
  mu.est  <- mu.est[keep]
  var.est <- var.est[keep]

  mu.est.inter  <- approx(x.mod,mu.est,xgOrig,'linear')$y
  var.est.inter <- approx(x.mod,var.est,xgOrig,'linear')$y

  # These estimates used to be overwritten below without being scored;
  # record their errors, as the second scenario does.
  mse.mu.uneven.haar[j] <- mean((mu.est.inter-fTrue(xgOrig))^2)
  mse.sd.uneven.haar[j] <-
    mean((sqrt(var.est.inter)-exp((loggTrue(xgOrig))/2))^2)

  # SMASH with filter.number = 8 from the "DaubLeAsymm" family; this is the
  # SMASH result reported in the summary table.
  mu.est  <- smash.gaus(y.final,filter.number=8,family="DaubLeAsymm")
  var.est <- smash.gaus(y.final,v.est=TRUE,v.basis=TRUE,filter.number=8,
                        family="DaubLeAsymm")

  mu.est  <- mu.est[keep]
  var.est <- var.est[keep]

  mu.est.inter  <- approx(x.mod,mu.est,xgOrig,'linear')$y
  var.est.inter <- approx(x.mod,var.est,xgOrig,'linear')$y

  mse.mu.uneven.s8[j] <- mean((mu.est.inter-fTrue(xgOrig))^2)
  mse.sd.uneven.s8[j] <-
    mean((sqrt(var.est.inter)-exp((loggTrue(xgOrig))/2))^2)
}
cat("\n")

Second simulation scenario: evenly spaced points

# Per-data-set errors for the second scenario. Each vector is created here
# and gains one entry per simulated data set in the loop that follows.
# Every vector that the loop assigns to must exist beforehand, otherwise
# the first assignment fails with "object not found".
#
# Mean squared error of the mean estimate.
mse.mu.even.mfvb <- 0
mse.mu.even.haar <- 0
mse.mu.even.s8   <- 0

# Mean squared error of the standard deviation estimate.
mse.sd.even.mfvb <- 0
mse.sd.even.haar <- 0
mse.sd.even.s8   <- 0
mse.sd.even.j    <- 0
mse.sd.even.s8.j <- 0

# Mean squared error of the log standard deviation estimate.
lmse.sd.even.mfvb <- 0
lmse.sd.even.haar <- 0
lmse.sd.even.s8   <- 0
lmse.sd.even.j    <- 0
lmse.sd.even.s8.j <- 0

# Per-data-set errors for Scenario 2, one entry per simulated data set.
# They are allocated here so that every vector assigned inside the loop
# exists before its first use.
mse.mu.even.mfvb  <- numeric(nsim2)
mse.mu.even.haar  <- numeric(nsim2)
mse.mu.even.s8    <- numeric(nsim2)
mse.sd.even.mfvb  <- numeric(nsim2)
mse.sd.even.haar  <- numeric(nsim2)
mse.sd.even.s8    <- numeric(nsim2)
mse.sd.even.j     <- numeric(nsim2)
mse.sd.even.s8.j  <- numeric(nsim2)
lmse.sd.even.mfvb <- numeric(nsim2)
lmse.sd.even.haar <- numeric(nsim2)
lmse.sd.even.s8   <- numeric(nsim2)
lmse.sd.even.j    <- numeric(nsim2)
lmse.sd.even.s8.j <- numeric(nsim2)

# Repeat for each data set simulated in the second simulation scenario.
cat(sprintf("Running %d simulations: ",nsim2))
for (j in seq_len(nsim2)) {
  cat(sprintf("%d ",j))

  # Create the simulated data set
  # -----------------------------
  # The x values are 2^10 equally spaced points; only the noise is random.
  n      <- 2^10
  xOrig  <- (1:n)/n
  set.seed(30*j)
  yOrig  <- fTrue(xOrig) + sqrt(exp(loggTrue(xOrig)))*rnorm(n)
  aOrig  <- min(xOrig)
  bOrig  <- max(xOrig)
  mean.x <- mean(xOrig)
  sd.x   <- sd(xOrig)
  mean.y <- mean(yOrig)
  sd.y   <- sd(yOrig)

  # Standardize x and y before fitting the MFVB model; a and b are the
  # endpoints of the x range on the standardized scale.
  a <- (aOrig - mean.x)/sd.x
  b <- (bOrig - mean.x)/sd.x
  x <- (xOrig - mean.x)/sd.x
  y <- (yOrig - mean.y)/sd.y

  # Spline bases for the mean (U) and variance (V) functions, with interior
  # knots at equally spaced quantiles of x.
  # NOTE(review): the variable names in this section are kept exactly as
  # they were because meanVarMFVB() is not passed Zu, Zv, Cumat or Cvmat and
  # may look them up by name -- confirm in mfvb.R before renaming.
  numIntKnotsU <- 17
  intKnotsU <-
    quantile(x,seq(0,1,length.out = numIntKnotsU + 2)[-c(1,numIntKnotsU+2)])
  Zu        <- ZOSull(x,intKnots=intKnotsU,range.x=c(a,b))
  numKnotsU <- ncol(Zu)

  numIntKnotsV <- numIntKnotsU
  intKnotsV <-
    quantile(x,seq(0,1,length.out = numIntKnotsV + 2)[-c(1,numIntKnotsV+2)])
  Zv        <- ZOSull(x,intKnots=intKnotsV,range.x=c(a,b))
  numKnotsV <- ncol(Zv)

  # Run Mean Field Variational Bayes (MFVB)
  # ---------------------------------------
  X     <- cbind(rep(1,n),x)
  Cumat <- cbind(X,Zu)
  Cvmat <- cbind(X,Zv)
  ncX   <- ncol(X)
  ncZu  <- ncol(Zu)
  ncZv  <- ncol(Zv)
  ncCu  <- ncol(Cumat)
  ncCv  <- ncol(Cvmat)

  MFVBfit <- meanVarMFVB(y,X,ncZu,ncZv,Au.hyp,Av.hyp,
                         sigsq.gamma,sigsq.beta)

  # Design matrices on the evaluation grid, which has as many points as
  # there are observations.
  ng     <- 2^10
  xgOrig <- seq(aOrig,bOrig,length.out = ng)
  xg  <- (xgOrig - mean.x)/sd.x
  Xg  <- cbind(rep(1,ng),xg)
  Zug <- ZOSull(xg,intKnots=intKnotsU,range.x=c(a,b))
  Cug <- cbind(Xg,Zug)
  Zvg <- ZOSull(xg,intKnots=intKnotsV,range.x=c(a,b))
  Cvg <- cbind(Xg,Zvg)

  # True mean, standard deviation and log standard deviation on the grid;
  # every error below is measured against one of these.
  f.true     <- fTrue(xgOrig)
  logsd.true <- (loggTrue(xgOrig))/2
  sd.true    <- exp(logsd.true)

  mu.q.nu       <- MFVBfit$mu.q.nu
  mu.q.omega    <- MFVBfit$mu.q.omega
  Sigma.q.nu    <- MFVBfit$Sigma.q.nu
  Sigma.q.omega <- MFVBfit$Sigma.q.omega

  # Get the mean function estimate.
  fhatMFVBg     <- Cug %*% mu.q.nu
  fhatMFVBgOrig <- fhatMFVBg*sd.y + mean.y

  # Log-variance estimate, mapped back to the original scale of y.
  logghatMFVBg     <- Cvg%*%mu.q.omega
  logghatMFVBgOrig <- logghatMFVBg + 2*log(sd.y)

  # diag(Cvg Sigma Cvg') goes through an ng x ng matrix and is needed twice
  # below, so compute it once.
  var.logg <- diag(Cvg%*%Sigma.q.omega%*%t(Cvg))

  sdloggMFVBgOrig      <- sqrt(var.logg)
  credLowloggMFVBgOrig <- logghatMFVBgOrig - qnorm(0.975)*sdloggMFVBgOrig
  credUpploggMFVBgOrig <- logghatMFVBgOrig + qnorm(0.975)*sdloggMFVBgOrig

  # Standard deviation estimate on the standardized scale, then rescaled.
  sqrtghatMFVBg     <- exp(0.5*logghatMFVBg + 0.125*var.logg)
  sqrtghatMFVBgOrig <- sqrtghatMFVBg*sd.y

  mse.mu.even.mfvb[j]  <- mean((fhatMFVBgOrig-f.true)^2)
  mse.sd.even.mfvb[j]  <- mean((sqrtghatMFVBgOrig-sd.true)^2)
  lmse.sd.even.mfvb[j] <- mean((log(sqrtghatMFVBgOrig)-logsd.true)^2)

  # Run SMASH
  # ---------
  # First with filter.number = 1; results are stored under the "haar" name.
  mu.est  <- smash.gaus(yOrig,filter.number=1,family="DaubExPhase")
  var.est <- smash.gaus(yOrig,v.est=TRUE)

  mse.mu.even.haar[j]  <- mean((mu.est-f.true)^2)
  mse.sd.even.haar[j]  <- mean((sqrt(var.est)-sd.true)^2)
  lmse.sd.even.haar[j] <- mean((log(sqrt(var.est))-logsd.true)^2)

  # Then with filter.number = 8 from the "DaubLeAsymm" family ("s8"), and
  # with the jash option switched on (".j" suffix) for the variance
  # estimate under both sets of settings.
  mu.est       <- smash.gaus(yOrig,filter.number=8,family="DaubLeAsymm")
  var.est      <- smash.gaus(yOrig,v.est=TRUE,v.basis=TRUE,filter.number=8,
                             family="DaubLeAsymm")
  var.est.s8.j <- smash.gaus(yOrig,v.est=TRUE,v.basis=TRUE,jash=TRUE,
                             filter.number=8,family="DaubLeAsymm")
  var.est.j    <- smash.gaus(yOrig,v.est=TRUE,jash=TRUE)

  mse.mu.even.s8[j]    <- mean((mu.est-f.true)^2)
  mse.sd.even.s8[j]    <- mean((sqrt(var.est)-sd.true)^2)
  mse.sd.even.j[j]     <- mean((sqrt(var.est.j)-sd.true)^2)
  mse.sd.even.s8.j[j]  <- mean((sqrt(var.est.s8.j)-sd.true)^2)
  lmse.sd.even.s8[j]   <- mean((log(sqrt(var.est))-logsd.true)^2)
  lmse.sd.even.j[j]    <- mean((log(sqrt(var.est.j))-logsd.true)^2)
  lmse.sd.even.s8.j[j] <- mean((log(sqrt(var.est.s8.j))-logsd.true)^2)
}
cat("\n")

# SUMMARIZE RESULTS
# -----------------
# Average each per-simulation MSE vector and arrange the averages in a
# 2 x 4 table: one row per method, and for each scenario one column for the
# error in the mean and one for the error in the standard deviation.
# Compare this table against Table 1 in the paper.
cat("Summarizing results of simulations.\n")
cat(sprintf("MSE averaged across %d simulations in Scenario 1,\n",nsim1))
cat(sprintf("and averaged across %d simulations in Scenario 2:\n",nsim2))
mse.by.method <- list(
  MFVB  = list(mse.mu.uneven.mfvb,mse.sd.uneven.mfvb,
               mse.mu.even.mfvb,mse.sd.even.mfvb),
  SMASH = list(mse.mu.uneven.s8,mse.sd.uneven.s8,
               mse.mu.even.s8,mse.sd.even.s8))
mse.averages <- lapply(mse.by.method,
                       function (mses) vapply(mses,mean,numeric(1)))
mse.table <- matrix(unlist(mse.averages,use.names = FALSE),
                    nrow = 2,byrow = TRUE,
                    dimnames = list(names(mse.by.method),
                                    c("mean","sd","mean","sd")))
cat("         Scenario 1      Scenario 2   \n")
cat("      --------------- ----------------\n")
print(mse.table)

Session information

# Print the R version, platform and package versions of the current session.
sessionInfo()
# R version 3.4.3 (2017-11-30)
# Platform: x86_64-apple-darwin15.6.0 (64-bit)
# Running under: macOS High Sierra 10.13.6
# 
# Matrix products: default
# BLAS: /Library/Frameworks/R.framework/Versions/3.4/Resources/lib/libRblas.0.dylib
# LAPACK: /Library/Frameworks/R.framework/Versions/3.4/Resources/lib/libRlapack.dylib
# 
# locale:
# [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
# 
# attached base packages:
# [1] stats     graphics  grDevices utils     datasets  methods   base     
# 
# other attached packages:
# [1] smashr_1.2-0
# 
# loaded via a namespace (and not attached):
#  [1] Rcpp_0.12.19      knitr_1.20        whisker_0.3-2    
#  [4] magrittr_1.5      workflowr_1.1.1   MASS_7.3-48      
#  [7] pscl_1.5.2        doParallel_1.0.11 SQUAREM_2017.10-1
# [10] lattice_0.20-35   foreach_1.4.4     ashr_2.2-23      
# [13] stringr_1.3.1     caTools_1.17.1    tools_3.4.3      
# [16] parallel_3.4.3    grid_3.4.3        data.table_1.11.4
# [19] R.oo_1.21.0       git2r_0.23.0      iterators_1.0.9  
# [22] htmltools_0.3.6   yaml_2.2.0        rprojroot_1.3-2  
# [25] digest_0.6.17     Matrix_1.2-12     bitops_1.0-6     
# [28] codetools_0.2-15  R.utils_2.6.0     evaluate_0.11    
# [31] rmarkdown_1.10    wavethresh_4.6.8  stringi_1.2.4    
# [34] compiler_3.4.3    backports_1.1.2   R.methodsS3_1.7.1
# [37] truncnorm_1.0-8

This reproducible R Markdown analysis was created with workflowr 1.1.1