Fast sampling from Truncated Normal Distribution using Rcpp and openMP
    <p>UPDATE:</p> <p>I tried to implement Dirk's suggestions. Comments? I am busy right now at JSM, but I'd like to get some feedback before knitting an Rmd for the gallery. I switched back from Armadillo to normal Rcpp, as it didn't add any value. Scalar versions with R:: are quite nice. I should maybe put in a parameter n for the number of draws if mean/sd are entered as scalar, not as vectors of the desired output length.</p> <hr> <p>There are lots of MCMC application that require drawing samples from truncated Normal distributions. I built on an existing implementation of the TN and added parallel computation to it.</p> <p>Issues:</p> <ol> <li>Does anyone see further potential speed improvements? <strike>In the last case from the benchmark, rtruncnorm is sometimes faster.</strike> The Rcpp implementation is always faster than existing packages, but can it be improved even further?</li> <li>I ran it within a complex model I can't share, and my R session crashed. However, I cannot systematically reproduce it, so it could have been another part of the code. If someone is working with the TN, please test it and let me know. Update: I haven't had issues with the updated code, but let me know.</li> </ol> <p>How I put things together: To my knowledge, the fastest implementation is not on CRAN, but the source code can be downloaded <a href="http://www.stat.osu.edu/~pfc/software/truncatedNormals/" rel="nofollow" title="at OSU's stat dept">OSU stat</a>. Competing implementations in <em>msm</em> and <em>truncorm</em> were slower in my benchmarks. The trick is to efficiently adjust proposal distributions, where the Exponential works nicely for the tails of the truncated Normal. So I took Chris' code, "Rcpp'ed" it and added some openMP spice to it. The dynamic schedule is optimal here, as sampling can take more or less time depending on the boundaries. One thing I found nasty: lots of the statistical distributions are based on the NumericVector type, when I wanted to work with doubles. I just coded my way around that.</p> <p>Heres the Rcpp code:</p> <pre><code>#include &lt;Rcpp.h&gt; #include &lt;omp.h&gt; // norm_rs(a, b) // generates a sample from a N(0,1) RV restricted to be in the interval // (a,b) via rejection sampling. // ====================================================================== // [[Rcpp::export]] double norm_rs(double a, double b) { double x; x = Rf_rnorm(0.0, 1.0); while( (x &lt; a) || (x &gt; b) ) x = norm_rand(); return x; } // half_norm_rs(a, b) // generates a sample from a N(0,1) RV restricted to the interval // (a,b) (with a &gt; 0) using half normal rejection sampling. // ====================================================================== // [[Rcpp::export]] double half_norm_rs(double a, double b) { double x; x = fabs(norm_rand()); while( (x&lt;a) || (x&gt;b) ) x = fabs(norm_rand()); return x; } // unif_rs(a, b) // generates a sample from a N(0,1) RV restricted to the interval // (a,b) using uniform rejection sampling. 
Here's the Rcpp code:

```cpp
#include <Rcpp.h>
#include <omp.h>

// norm_rs(a, b)
// generates a sample from a N(0,1) RV restricted to the interval
// (a,b) via plain rejection sampling.
// ======================================================================

// [[Rcpp::export]]
double norm_rs(double a, double b)
{
  double x;
  x = Rf_rnorm(0.0, 1.0);
  while ((x < a) || (x > b)) x = norm_rand();
  return x;
}

// half_norm_rs(a, b)
// generates a sample from a N(0,1) RV restricted to the interval
// (a,b) (with a > 0) using half-normal rejection sampling.
// ======================================================================

// [[Rcpp::export]]
double half_norm_rs(double a, double b)
{
  double x;
  x = fabs(norm_rand());
  while ((x < a) || (x > b)) x = fabs(norm_rand());
  return x;
}

// unif_rs(a, b)
// generates a sample from a N(0,1) RV restricted to the interval
// (a,b) using uniform rejection sampling.
// ======================================================================

// [[Rcpp::export]]
double unif_rs(double a, double b)
{
  double xstar, logphixstar, x, logu;

  // Find the argmax (b is always >= 0)
  // This works because we want to sample from N(0,1)
  if (a <= 0.0) xstar = 0.0;
  else xstar = a;
  logphixstar = R::dnorm(xstar, 0.0, 1.0, 1.0);

  x = R::runif(a, b);
  logu = log(R::runif(0.0, 1.0));
  while (logu > (R::dnorm(x, 0.0, 1.0, 1.0) - logphixstar))
  {
    x = R::runif(a, b);
    logu = log(R::runif(0.0, 1.0));
  }
  return x;
}

// exp_rs(a, b)
// generates a sample from a N(0,1) RV restricted to the interval
// (a,b) using exponential rejection sampling.
// ======================================================================

// [[Rcpp::export]]
double exp_rs(double a, double b)
{
  double z, u, rate;

  // Rprintf("in exp_rs");
  // R::rexp() takes the scale (mean), so 1/a corresponds to an
  // Exponential proposal with rate a
  rate = 1 / a;

  // Generate a proposal on (0, b-a)
  z = R::rexp(rate);
  while (z > (b - a)) z = R::rexp(rate);
  u = R::runif(0.0, 1.0);

  while (log(u) > (-0.5 * z * z))
  {
    z = R::rexp(rate);
    while (z > (b - a)) z = R::rexp(rate);
    u = R::runif(0.0, 1.0);
  }
  return (z + a);
}

// rnorm_trunc(mu, sigma, lower, upper)
//
// generates one random normal RV with mean 'mu' and standard
// deviation 'sigma', truncated to the interval (lower, upper), where
// lower can be -Inf and upper can be Inf.
// ======================================================================

// [[Rcpp::export]]
double rnorm_trunc(double mu, double sigma, double lower, double upper)
{
  int change;
  double a, b;
  double logt1 = log(0.150), logt2 = log(2.18), t3 = 0.725;
  double z, tmp, lograt;

  change = 0;
  a = (lower - mu) / sigma;
  b = (upper - mu) / sigma;

  // First scenario: at least one bound is infinite
  if ((a == R_NegInf) || (b == R_PosInf))
  {
    if (a == R_NegInf)
    {
      change = 1;
      a = -b;
      b = R_PosInf;
    }

    // The two possibilities for this scenario
    if (a <= 0.45) z = norm_rs(a, b);
    else z = exp_rs(a, b);
    if (change) z = -z;
  }
  // Second scenario: the interval contains 0
  else if ((a * b) <= 0.0)
  {
    // The two possibilities for this scenario
    if ((R::dnorm(a, 0.0, 1.0, 1.0) <= logt1) ||
        (R::dnorm(b, 0.0, 1.0, 1.0) <= logt1))
    {
      z = norm_rs(a, b);
    }
    else z = unif_rs(a, b);
  }
  // Third scenario: both bounds finite, interval entirely in one tail
  else
  {
    if (b < 0)
    {
      tmp = b; b = -a; a = -tmp; change = 1;
    }

    lograt = R::dnorm(a, 0.0, 1.0, 1.0) - R::dnorm(b, 0.0, 1.0, 1.0);
    if (lograt <= logt2) z = unif_rs(a, b);
    else if ((lograt > logt1) && (a < t3)) z = half_norm_rs(a, b);
    else z = exp_rs(a, b);
    if (change) z = -z;
  }

  return sigma * z + mu;
}

// rtnm(mus, sigmas, lower, upper, cores)
//
// generates random normal RVs with means 'mus' and standard
// deviations 'sigmas', truncated to the intervals (lower, upper), where
// lower can be -Inf and upper can be Inf.
// mus, sigmas, lower, upper are vectors of equal length; vectorized calls
// of this function speed up computation.
// cores is an integer, the number of cores to be used in parallel.
// ======================================================================

// [[Rcpp::export]]
Rcpp::NumericVector rtnm(Rcpp::NumericVector mus, Rcpp::NumericVector sigmas,
                         Rcpp::NumericVector lower, Rcpp::NumericVector upper,
                         int cores)
{
  omp_set_num_threads(cores);
  int nobs = mus.size();
  Rcpp::NumericVector out(nobs);
  double logt1 = log(0.150), logt2 = log(2.18), t3 = 0.725;

  // CAUTION: R's RNG (norm_rand, R::runif, ...) is not guaranteed to be
  // thread-safe, so calling it from several OpenMP threads at once may be
  // the source of the sporadic crashes mentioned above.
  #pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < nobs; i++)
  {
    // work variables are declared inside the loop so that each
    // iteration (and thus each thread) gets its own private copy
    double a, b, z, tmp, lograt;
    int change = 0;

    a = (lower(i) - mus(i)) / sigmas(i);
    b = (upper(i) - mus(i)) / sigmas(i);

    // First scenario: at least one bound is infinite
    if ((a == R_NegInf) || (b == R_PosInf))
    {
      if (a == R_NegInf)
      {
        change = 1;
        a = -b;
        b = R_PosInf;
      }

      // The two possibilities for this scenario
      if (a <= 0.45) z = norm_rs(a, b);
      else z = exp_rs(a, b);
      if (change) z = -z;
    }
    // Second scenario: the interval contains 0
    else if ((a * b) <= 0.0)
    {
      // The two possibilities for this scenario
      if ((R::dnorm(a, 0.0, 1.0, 1.0) <= logt1) ||
          (R::dnorm(b, 0.0, 1.0, 1.0) <= logt1))
      {
        z = norm_rs(a, b);
      }
      else z = unif_rs(a, b);
    }
    // Third scenario: both bounds finite, interval entirely in one tail
    else
    {
      if (b < 0)
      {
        tmp = b; b = -a; a = -tmp; change = 1;
      }

      lograt = R::dnorm(a, 0.0, 1.0, 1.0) - R::dnorm(b, 0.0, 1.0, 1.0);
      if (lograt <= logt2) z = unif_rs(a, b);
      else if ((lograt > logt1) && (a < t3)) z = half_norm_rs(a, b);
      else z = exp_rs(a, b);
      if (change) z = -z;
    }

    out(i) = sigmas(i) * z + mus(i);
  }

  return out;
}
```
And here is the benchmark:

```r
libs = c("truncnorm", "msm", "inline", "Rcpp", "RcppArmadillo", "rbenchmark")
if (sum(!(libs %in% .packages(all.available = TRUE))) > 0) {
  install.packages(libs[!(libs %in% .packages(all.available = TRUE))])
}
for (i in 1:length(libs)) {
  library(libs[i], character.only = TRUE, quietly = TRUE)
}

# needed for the openMP parallel version
Sys.setenv("PKG_CXXFLAGS" = "-fopenmp")
Sys.setenv("PKG_LIBS" = "-fopenmp")

# number of cores for the openMP version
cores = 4

# source code from the same dir
Rcpp::sourceCpp('truncnorm.cpp')

# sample size
nn = 1000000

bb = 100
aa = -100
benchmark(rtnm(rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn), cores),
          rtnm(rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn), 1),
          rtnorm(nn, rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn)),
          rtruncnorm(nn, a = aa, b = 100, mean = 0, sd = 1),
          order = "relative", replications = 3)[, 1:4]

aa = 0
benchmark(rtnm(rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn), cores),
          rtnm(rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn), 1),
          rtnorm(nn, rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn)),
          rtruncnorm(nn, a = aa, b = 100, mean = 0, sd = 1),
          order = "relative", replications = 3)[, 1:4]

aa = 2
benchmark(rtnm(rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn), cores),
          rtnm(rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn), 1),
          rtnorm(nn, rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn)),
          rtruncnorm(nn, a = aa, b = 100, mean = 0, sd = 1),
          order = "relative", replications = 3)[, 1:4]

aa = 50
benchmark(rtnm(rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn), cores),
          rtnm(rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn), 1),
          rtnorm(nn, rep(0, nn), rep(1, nn), rep(aa, nn), rep(100, nn)),
          rtruncnorm(nn, a = aa, b = 100, mean = 0, sd = 1),
          order = "relative", replications = 3)[, 1:4]
```

Several benchmark runs are necessary because the speed depends on the upper/lower boundaries; for different cases, different parts of the algorithm kick in.
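Following up on the note in the UPDATE about a separate `n` argument, here is a minimal sketch of an R-side wrapper that recycles scalar arguments to the requested length before calling `rtnm()`. The wrapper name `rtnorm_rcpp` and its argument checks are hypothetical additions of mine, not part of the code above.

```r
# Hypothetical convenience wrapper around rtnm(): accepts a draw count n and
# recycles scalar mean/sd/bounds to length n, so callers need not rep() by hand.
rtnorm_rcpp <- function(n, mean = 0, sd = 1, lower = -Inf, upper = Inf, cores = 1) {
  stopifnot(n >= 1, all(sd > 0), all(lower < upper))
  rtnm(rep_len(mean, n), rep_len(sd, n),
       rep_len(lower, n), rep_len(upper, n), cores)
}

# Quick sanity check against truncnorm::rtruncnorm: the first two moments
# of the two samplers should roughly agree.
set.seed(1)
x <- rtnorm_rcpp(1e5, mean = 0, sd = 1, lower = 2, upper = 100)
y <- truncnorm::rtruncnorm(1e5, a = 2, b = 100, mean = 0, sd = 1)
c(mean(x), mean(y))
c(sd(x), sd(y))
```

Keeping the recycling on the R side leaves the C++ interface unchanged and matches how the benchmark above already calls `rtnm()` with `rep()`-ed vectors.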