<p>Here is a solution based on <code>RcppArmadillo</code>. In the benchmark below it runs about 130 times faster than the <code>mvtnorm</code>-based R implementation and about 370 times faster than the pure R one. First, the C++ implementation, which builds on <a href="http://gallery.rcpp.org/articles/dmvnorm_arma/" rel="noreferrer">this Rcpp Gallery example</a>.</p>

<pre><code>// mvnorm.cpp
#include &lt;RcppArmadillo.h&gt;
// [[Rcpp::depends(RcppArmadillo)]]

// Row-wise squared Mahalanobis distance, as in the linked Rcpp Gallery example.
arma::vec Mahalanobis(arma::mat x, arma::rowvec center, arma::mat cov) {
    int n = x.n_rows;
    arma::mat x_cen;
    x_cen.copy_size(x);
    for (int i = 0; i &lt; n; i++) {
        x_cen.row(i) = x.row(i) - center;
    }
    return sum((x_cen * cov.i()) % x_cen, 1);
}

// Gradient of the multivariate normal density with respect to x,
// evaluated at each row of x: -f(x) * sigma^{-1} * (x - mean).
// [[Rcpp::export]]
arma::mat dmvnormderiv_arma(arma::mat x, arma::rowvec mean, arma::mat sigma) {
    // multivariate normal density at each row of x
    arma::vec distval = Mahalanobis(x, mean, sigma);
    double logdet = sum(arma::log(arma::eig_sym(sigma)));
    double log2pi = std::log(2.0 * M_PI);
    arma::vec mvnorm = exp(-((x.n_cols * log2pi + logdet + distval) / 2));

    // output matrix: one row per observation, one column per partial derivative
    int n = x.n_rows;
    arma::mat deriv;
    deriv.copy_size(x);
    for (int i = 0; i &lt; n; i++) {
        deriv.row(i) = -1 * mvnorm(i) * trans(solve(sigma, trans(x.row(i) - mean)));
    }
    return deriv;
}
</code></pre>

<p>Next, two R implementations for comparison: one in pure R, and one based on <code>dmvnorm</code> from the <code>mvtnorm</code> package.</p>

<pre><code>library('Rcpp')        # sourceCpp()
library('RcppArmadillo')
library('mvtnorm')     # dmvnorm(), rmvnorm()
library('rbenchmark')  # benchmark()

sourceCpp('mvnorm.cpp')

# Pure R: density and gradient evaluated row by row
mvnormDeriv = function(X, mu=rep(0,ncol(X)), sigma=diag(ncol(X))) {
  fn = function(x) -1 * c((1/sqrt(det(2*pi*sigma))) *
                          exp(-0.5*t(x-mu)%*%solve(sigma)%*%(x-mu))) *
                   solve(sigma,(x-mu))
  out = t(apply(X,1,fn))
  return(out)
}

# Density from mvtnorm::dmvnorm, gradient computed in a loop
dmvnormDeriv = function(X, mean, sigma) {
  if (is.vector(X)) X &lt;- matrix(X, ncol = length(X))
  if (missing(mean)) mean &lt;- rep(0, length = ncol(X))
  if (missing(sigma)) sigma &lt;- diag(ncol(X))
  n = nrow(X)
  mvnorm = dmvnorm(X, mean = mean, sigma = sigma)
  deriv = array(NA,c(n,ncol(X)))
  for (i in seq_len(n)) deriv[i,] = -mvnorm[i] * solve(sigma,(X[i,]-mean))
  return(deriv)
}
</code></pre>

<p>Finally, some benchmarks (10,000 draws from a bivariate normal, 5 replications each):</p>

<pre><code>set.seed(123456789)
sigma = rWishart(1, 2, diag(2))[,,1]
means = rnorm(2)
X = rmvnorm(10000, means, sigma)
benchmark(dmvnormderiv_arma(X,means,sigma),
          mvnormDeriv(X,mu=means,sigma=sigma),
          dmvnormDeriv(X,mean=means,sigma=sigma),
          order="relative", replications=5)[,1:4]

                                          test replications elapsed relative
1           dmvnormderiv_arma(X, means, sigma)            5   0.016    1.000
3 dmvnormDeriv(X, mean = means, sigma = sigma)            5   2.118  132.375
2    mvnormDeriv(X, mu = means, sigma = sigma)            5   5.939  371.187
</code></pre>
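<p>All three implementations evaluate the same closed-form gradient, ∇f(x) = −f(x) Σ<sup>−1</sup>(x − μ), where f is the N(μ, Σ) density. One way to sanity-check that formula without R or Armadillo is the following minimal plain C++ sketch. It is restricted to the bivariate case used in the benchmark, inverts the 2×2 covariance in closed form, and compares the analytic gradient with central finite differences. It is not part of the benchmarked code. The names <code>dmvnorm2</code>, <code>dmvnorm2_deriv</code>, <code>solve2</code>, <code>numeric_partial</code> and the step size <code>h = 1e-6</code> are illustrative assumptions.</p>

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical fixed-size types for this 2-D sketch (not from the answer).
using Vec2 = std::array<double, 2>;
using Mat2 = std::array<std::array<double, 2>, 2>;

// Determinant of a 2x2 matrix.
double det2(const Mat2& s) {
    return s[0][0] * s[1][1] - s[0][1] * s[1][0];
}

// sigma^{-1} * d using the closed-form inverse of a 2x2 matrix.
Vec2 solve2(const Mat2& s, const Vec2& d) {
    const double det = det2(s);
    return {( s[1][1] * d[0] - s[0][1] * d[1]) / det,
            (-s[1][0] * d[0] + s[0][0] * d[1]) / det};
}

// Bivariate normal density N(mu, sigma) at x; sigma must be positive definite.
double dmvnorm2(const Vec2& x, const Vec2& mu, const Mat2& s) {
    const double det = det2(s);
    assert(s[0][0] > 0.0 && det > 0.0 && "sigma must be positive definite");
    const Vec2 d = {x[0] - mu[0], x[1] - mu[1]};
    const Vec2 sd = solve2(s, d);
    const double q = d[0] * sd[0] + d[1] * sd[1];  // squared Mahalanobis distance
    const double pi = std::acos(-1.0);
    return std::exp(-0.5 * q) / (2.0 * pi * std::sqrt(det));
}

// Analytic gradient: -f(x) * sigma^{-1} * (x - mu), the formula used in the answer.
Vec2 dmvnorm2_deriv(const Vec2& x, const Vec2& mu, const Mat2& s) {
    const double f = dmvnorm2(x, mu, s);
    const Vec2 sd = solve2(s, {x[0] - mu[0], x[1] - mu[1]});
    return {-f * sd[0], -f * sd[1]};
}

// Central finite-difference approximation of df/dx_j, used only as a check.
double numeric_partial(const Vec2& x, const Vec2& mu, const Mat2& s,
                       std::size_t j, double h = 1e-6) {
    Vec2 xp = x, xm = x;
    xp[j] += h;
    xm[j] -= h;
    return (dmvnorm2(xp, mu, s) - dmvnorm2(xm, mu, s)) / (2.0 * h);
}
```

<p>For a test point and a positive-definite <code>sigma</code>, <code>dmvnorm2_deriv(x, mu, sigma)[j]</code> should agree with <code>numeric_partial(x, mu, sigma, j)</code> to many decimal places, and the gradient is exactly zero at <code>x == mu</code>. In more than two dimensions the closed-form inverse would have to be replaced by a linear solve, which is what <code>solve(sigma, ...)</code> does in the code above.</p>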
 
