Last updated: 2018-10-07

workflowr checks: (Click a bullet for more information)
Expand here to see past versions:


The problem

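Model: \(X\sim \text{Poisson}(\mu)\), and define \(y=\log m+\frac{x-m}{m}\), the first-order Taylor expansion of \(\log x\) around \(m\). (In the simulations below, \(\log\mu\) is additionally perturbed by a \(N(0,\sigma^2)\) "nugget".)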

Previously, we used \(m=\) the ash posterior mean of \(x\). The problem is that the shrinkage effect is so strong that, for large observations \(x\), the approximated normal data points are too large. Hence, after taking the exponential of the estimated normal mean, the estimate ‘blows up’.

Now we instead do the Taylor series expansion around \(\log x\) (the MLE) for non-zero \(x\), and around the ash posterior mean for zero \(x\). In other words, \(m=\) the ash posterior mean for the zero \(x\)s and \(m=x\) for the non-zero \(x\)s.

Spike mean function

library(ashr)
library(smashrgen)
# Spike test function: a sum of five Gaussian bumps
#   height_k * exp(-sharpness_k * (x - centre_k)^2),
# with the parameters below. The bumps are added left to right in the listed
# order, so the floating-point result matches a hand-written sum exactly.
spike.f <- function(x) {
  heights <- c(0.75, 1.5, 3, 2.25, 0.5)
  sharpness <- c(500, 2000, 8000, 16000, 32000)
  centres <- c(0.23, 0.33, 0.47, 0.69, 0.83)
  bumps <- Map(function(h, s, ctr) h * exp(-s * (x - ctr)^2),
               heights, sharpness, centres)
  Reduce(`+`, bumps)
}
# Grid of n = 256 equally spaced points i/n on (0, 1], and the true Poisson
# mean: the spike function scaled by 2 and shifted up by 0.1 so the rate is
# strictly positive (log(m) is taken below).
n <- 256
t <- seq_len(n) / n
m <- 2 * spike.f(t) + 0.1
range(m)
[1] 0.100000 6.025467
# ---- n = 256, nugget sd sig = 0 ----
sig <- 0

# Simulate counts: log-rate = log(m) + N(0, sig^2) nugget; x ~ Poisson(rate).
set.seed(12345)
lambda <- exp(log(m) + rnorm(n, 0, sig))
x <- rpois(n, lambda)

# Expansion point m.hat: the ash posterior mean for zero counts, and the
# observed count itself (the MLE) wherever x is non-zero.
x.ash <- ash(rep(0, n), 1, lik = lik_pois(x))$result$PosteriorMean
m.hat <- replace(x.ash, x != 0, x[x != 0])

# First-order Taylor expansion of log(x) around m.hat.
y <- log(m.hat) + (x - m.hat) / m.hat

# Smooth y and map back to the count scale. The "known var" fit passes the
# approximate sd of y (nugget variance plus delta-method Poisson term
# 1/m.hat); the "unknown var" fit lets smash.gaus handle the variance itself.
m.tilde <- exp(smash.gaus(y, sigma = sqrt(sig^2 + 1 / m.hat)))
m.tilde2 <- exp(smash.gaus(y))

# Start a 2 x 2 panel layout; the following fits fill the remaining panels.
par(mfrow = c(2, 2))

plot(x, col = "grey80", ylab = "", xlab = "", main = "n=256,nugget=0")
lines(m, col = "grey60")
lines(m.tilde, col = 2)
lines(m.tilde2, col = 4)

legend("topleft",
       c("data", "true mean", "smashgen-known var", "smashgen-unknown var"),
       lty = c(0, 1, 1, 1), pch = c(1, NA, NA, NA),
       col = c("grey80", "grey60", 2, 4))
#################

# ---- n = 256, nugget sd sig = 0.1 ----
sig <- 0.1

# Simulate counts: log-rate = log(m) + N(0, sig^2) nugget; x ~ Poisson(rate).
set.seed(12345)
lambda <- exp(log(m) + rnorm(n, 0, sig))
x <- rpois(n, lambda)

# Expansion point m.hat: the ash posterior mean for zero counts, and the
# observed count itself (the MLE) wherever x is non-zero.
x.ash <- ash(rep(0, n), 1, lik = lik_pois(x))$result$PosteriorMean
m.hat <- replace(x.ash, x != 0, x[x != 0])

# First-order Taylor expansion of log(x) around m.hat.
y <- log(m.hat) + (x - m.hat) / m.hat

# Smooth y and map back to the count scale: with the approximate sd of y
# supplied ("known var"), and without it ("unknown var").
m.tilde <- exp(smash.gaus(y, sigma = sqrt(sig^2 + 1 / m.hat)))
m.tilde2 <- exp(smash.gaus(y))

plot(x, col = "grey80", ylab = "", xlab = "", main = "n=256,nugget=0.1")
lines(m, col = "grey60")
lines(m.tilde, col = 2)
lines(m.tilde2, col = 4)

legend("topleft",
       c("data", "true mean", "smashgen-known var", "smashgen-unknown var"),
       lty = c(0, 1, 1, 1), pch = c(1, NA, NA, NA),
       col = c("grey80", "grey60", 2, 4))

#################

# ---- n = 256, nugget sd sig = 1 ----
sig <- 1

# Simulate counts: log-rate = log(m) + N(0, sig^2) nugget; x ~ Poisson(rate).
set.seed(12345)
lambda <- exp(log(m) + rnorm(n, 0, sig))
x <- rpois(n, lambda)

# Expansion point m.hat: the ash posterior mean for zero counts, and the
# observed count itself (the MLE) wherever x is non-zero.
x.ash <- ash(rep(0, n), 1, lik = lik_pois(x))$result$PosteriorMean
m.hat <- replace(x.ash, x != 0, x[x != 0])

# First-order Taylor expansion of log(x) around m.hat.
y <- log(m.hat) + (x - m.hat) / m.hat

# Smooth y and map back to the count scale: with the approximate sd of y
# supplied ("known var"), and without it ("unknown var").
m.tilde <- exp(smash.gaus(y, sigma = sqrt(sig^2 + 1 / m.hat)))
m.tilde2 <- exp(smash.gaus(y))

plot(x, col = "grey80", ylab = "", xlab = "", main = "n=256,nugget=1")
lines(m, col = "grey60")
lines(m.tilde, col = 2)
lines(m.tilde2, col = 4)

legend("topleft",
       c("data", "true mean", "smashgen-known var", "smashgen-unknown var"),
       lty = c(0, 1, 1, 1), pch = c(1, NA, NA, NA),
       col = c("grey80", "grey60", 2, 4))

# Comparison panel: the earlier approach (expansion around the ash posterior
# mean everywhere), as implemented by smashrgen::smash_gen_lite.
# NOTE(review): assumes smash_gen_lite(x) returns the fitted mean vector on
# the count scale -- confirm against smashrgen.
plot(x, col = "grey80", ylab = "", xlab = "",
     main = "Previous version using ash posterior mean, nugget=1")
lines(m, col = "grey60")
lines(smash_gen_lite(x))
legend("topleft", c("data", "true mean", "fit"), lty = c(0, 1, 1),
       pch = c(1, NA, NA), col = c("grey80", "grey60", 1))

# Grid of n = 512 equally spaced points i/n on (0, 1], and the true Poisson
# mean: the spike function scaled by 2 and shifted up by 0.1 so the rate is
# strictly positive (log(m) is taken below).
n <- 512
t <- seq_len(n) / n
m <- 2 * spike.f(t) + 0.1
range(m)
[1] 0.100000 6.076316
# ---- n = 512, nugget sd sig = 0 ----
sig <- 0

# Simulate counts: log-rate = log(m) + N(0, sig^2) nugget; x ~ Poisson(rate).
set.seed(12345)
lambda <- exp(log(m) + rnorm(n, 0, sig))
x <- rpois(n, lambda)

# Expansion point m.hat: the ash posterior mean for zero counts, and the
# observed count itself (the MLE) wherever x is non-zero.
x.ash <- ash(rep(0, n), 1, lik = lik_pois(x))$result$PosteriorMean
m.hat <- replace(x.ash, x != 0, x[x != 0])

# First-order Taylor expansion of log(x) around m.hat.
y <- log(m.hat) + (x - m.hat) / m.hat

# Smooth y and map back to the count scale. The "known var" fit passes the
# approximate sd of y (nugget variance plus delta-method Poisson term
# 1/m.hat); the "unknown var" fit lets smash.gaus handle the variance itself.
m.tilde <- exp(smash.gaus(y, sigma = sqrt(sig^2 + 1 / m.hat)))
m.tilde2 <- exp(smash.gaus(y))

# Start a fresh 2 x 2 panel layout for the n = 512 figures.
par(mfrow = c(2, 2))

plot(x, col = "grey80", ylab = "", xlab = "", main = "n=512,nugget=0")
lines(m, col = "grey60")
lines(m.tilde, col = 2)
lines(m.tilde2, col = 4)

legend("topleft",
       c("data", "true mean", "smashgen-known var", "smashgen-unknown var"),
       lty = c(0, 1, 1, 1), pch = c(1, NA, NA, NA),
       col = c("grey80", "grey60", 2, 4))
#################

# ---- n = 512, nugget sd sig = 0.1 ----
sig <- 0.1

# Simulate counts: log-rate = log(m) + N(0, sig^2) nugget; x ~ Poisson(rate).
set.seed(12345)
lambda <- exp(log(m) + rnorm(n, 0, sig))
x <- rpois(n, lambda)

# Expansion point m.hat: the ash posterior mean for zero counts, and the
# observed count itself (the MLE) wherever x is non-zero.
x.ash <- ash(rep(0, n), 1, lik = lik_pois(x))$result$PosteriorMean
m.hat <- replace(x.ash, x != 0, x[x != 0])

# First-order Taylor expansion of log(x) around m.hat.
y <- log(m.hat) + (x - m.hat) / m.hat

# Smooth y and map back to the count scale: with the approximate sd of y
# supplied ("known var"), and without it ("unknown var").
m.tilde <- exp(smash.gaus(y, sigma = sqrt(sig^2 + 1 / m.hat)))
m.tilde2 <- exp(smash.gaus(y))

plot(x, col = "grey80", ylab = "", xlab = "", main = "n=512,nugget=0.1")
lines(m, col = "grey60")
lines(m.tilde, col = 2)
lines(m.tilde2, col = 4)

legend("topleft",
       c("data", "true mean", "smashgen-known var", "smashgen-unknown var"),
       lty = c(0, 1, 1, 1), pch = c(1, NA, NA, NA),
       col = c("grey80", "grey60", 2, 4))

#################

# ---- n = 512, nugget sd sig = 1 ----
sig <- 1

# Simulate counts: log-rate = log(m) + N(0, sig^2) nugget; x ~ Poisson(rate).
set.seed(12345)
lambda <- exp(log(m) + rnorm(n, 0, sig))
x <- rpois(n, lambda)

# Expansion point m.hat: the ash posterior mean for zero counts, and the
# observed count itself (the MLE) wherever x is non-zero.
x.ash <- ash(rep(0, n), 1, lik = lik_pois(x))$result$PosteriorMean
m.hat <- replace(x.ash, x != 0, x[x != 0])

# First-order Taylor expansion of log(x) around m.hat.
y <- log(m.hat) + (x - m.hat) / m.hat

# Smooth y and map back to the count scale: with the approximate sd of y
# supplied ("known var"), and without it ("unknown var").
m.tilde <- exp(smash.gaus(y, sigma = sqrt(sig^2 + 1 / m.hat)))
m.tilde2 <- exp(smash.gaus(y))

plot(x, col = "grey80", ylab = "", xlab = "", main = "n=512,nugget=1")
lines(m, col = "grey60")
lines(m.tilde, col = 2)
lines(m.tilde2, col = 4)

legend("topleft",
       c("data", "true mean", "smashgen-known var", "smashgen-unknown var"),
       lty = c(0, 1, 1, 1), pch = c(1, NA, NA, NA),
       col = c("grey80", "grey60", 2, 4))

# Comparison panel: the earlier approach (expansion around the ash posterior
# mean everywhere), as implemented by smashrgen::smash_gen_lite.
# NOTE(review): assumes smash_gen_lite(x) returns the fitted mean vector on
# the count scale -- confirm against smashrgen.
plot(x, col = "grey80", ylab = "", xlab = "",
     main = "Previous version using ash posterior mean, nugget=1")
lines(m, col = "grey60")
lines(smash_gen_lite(x))
legend("topleft", c("data", "true mean", "fit"), lty = c(0, 1, 1),
       pch = c(1, NA, NA), col = c("grey80", "grey60", 1))

Session information

# Print the R version, platform and attached package versions for reproducibility.
utils::sessionInfo()
R version 3.5.1 (2018-07-02)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS High Sierra 10.13.6

Matrix products: default
BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] smashrgen_0.1.0  wavethresh_4.6.8 MASS_7.3-50      caTools_1.17.1.1
[5] smashr_1.2-0     ashr_2.2-7      

loaded via a namespace (and not attached):
 [1] Rcpp_0.12.18      compiler_3.5.1    git2r_0.23.0     
 [4] workflowr_1.1.1   R.methodsS3_1.7.1 R.utils_2.7.0    
 [7] bitops_1.0-6      iterators_1.0.10  tools_3.5.1      
[10] digest_0.6.17     evaluate_0.11     lattice_0.20-35  
[13] Matrix_1.2-14     foreach_1.4.4     yaml_2.2.0       
[16] parallel_3.5.1    stringr_1.3.1     knitr_1.20       
[19] REBayes_1.3       rprojroot_1.3-2   grid_3.5.1       
[22] data.table_1.11.6 rmarkdown_1.10    magrittr_1.5     
[25] whisker_0.3-2     backports_1.1.2   codetools_0.2-15 
[28] htmltools_0.3.6   assertthat_0.2.0  stringi_1.2.4    
[31] Rmosek_8.0.69     doParallel_1.0.14 pscl_1.5.2       
[34] truncnorm_1.0-8   SQUAREM_2017.10-1 R.oo_1.22.0      

This reproducible R Markdown analysis was created with workflowr 1.1.1