[Rcpp-commits] r3726 - in pkg/RcppEigen/inst: include unitTests
noreply at r-forge.r-project.org
noreply at r-forge.r-project.org
Sat Jul 28 20:22:07 CEST 2012
Author: dmbates
Date: 2012-07-28 20:22:07 +0200 (Sat, 28 Jul 2012)
New Revision: 3726
Modified:
pkg/RcppEigen/inst/include/RcppEigenWrap.h
pkg/RcppEigen/inst/unitTests/runit.RcppEigen.R
Log:
More fixes to RcppEigenWrap.h, with the unit tests adjusted accordingly.
The changes allow RowMajor matrices to be wrapped (thanks to Gael
Guennebaud), but RowVector types are still not handled. Further template
metaprogramming will be needed to redirect the RowVector case, because a
RowVector type cannot be converted to a ColMajor form.
Modified: pkg/RcppEigen/inst/include/RcppEigenWrap.h
===================================================================
--- pkg/RcppEigen/inst/include/RcppEigenWrap.h 2012-07-28 18:18:59 UTC (rev 3725)
+++ pkg/RcppEigen/inst/include/RcppEigenWrap.h 2012-07-28 18:22:07 UTC (rev 3726)
@@ -26,8 +26,8 @@
namespace RcppEigen{
-// template<>
- SEXP Eigen_cholmod_wrap(const Eigen::CholmodDecomposition<Eigen::SparseMatrix<double> >& obj) {
+ template<typename T>
+ SEXP Eigen_cholmod_wrap(const Eigen::CholmodDecomposition<Eigen::SparseMatrix<T> >& obj) {
const cholmod_factor* f = obj.factor();
if (f->minor < f->n)
throw std::runtime_error("CHOLMOD factorization was unsuccessful");
@@ -63,8 +63,8 @@
} /* namespace RcppEigen */
- template<>
- SEXP wrap(const Eigen::CholmodDecomposition<Eigen::SparseMatrix<double> >& obj) {
+ template<typename T>
+ SEXP wrap(const Eigen::CholmodDecomposition<Eigen::SparseMatrix<T> >& obj) {
return RcppEigen::Eigen_cholmod_wrap(obj);
}
@@ -81,10 +81,14 @@
template <typename T>
SEXP eigen_wrap_plain_dense( const T& obj, Rcpp::traits::true_type ){
int m = obj.rows(), n = obj.cols();
- SEXP ans = PROTECT(::Rcpp::wrap(obj.data(), obj.data() + m * n));
+ typename Eigen::internal::conditional<T::IsRowMajor,
+ Eigen::Matrix<typename T::Scalar,
+ T::RowsAtCompileTime,
+ T::ColsAtCompileTime,
+ Eigen::ColMajor>,
+ const T&>::type objCopy(obj);
+ SEXP ans = PROTECT(::Rcpp::wrap(objCopy.data(), objCopy.data() + m * n));
if( T::ColsAtCompileTime != 1 ) {
- if (T::IsRowMajor)
- throw std::invalid_argument("R requires column-major dense matrices");
SEXP dd = PROTECT(::Rf_allocVector(INTSXP, 2));
int *d = INTEGER(dd);
d[0] = m;
Modified: pkg/RcppEigen/inst/unitTests/runit.RcppEigen.R
===================================================================
--- pkg/RcppEigen/inst/unitTests/runit.RcppEigen.R 2012-07-28 18:18:59 UTC (rev 3725)
+++ pkg/RcppEigen/inst/unitTests/runit.RcppEigen.R 2012-07-28 18:22:07 UTC (rev 3726)
@@ -45,7 +45,7 @@
_["Col<int>"] = Eigen::MatrixXi::Zero(5, 1)
);
-// List rows = List::create(
+// List rows = List::create( // Do not try to wrap row vectors
// _["Row<complex>"] = Eigen::RowVectorXcd::Zero(5),
// _["Row<double>"] = Eigen::RowVectorXd::Zero(5),
// _["Row<float>"] = Eigen::RowVectorXf::Zero(5),
@@ -102,43 +102,43 @@
res <- fx()
- checkEquals( res[[1]][[1]], complex(5), msg = "VectorXcd::Zero(5)")
- checkEquals( res[[1]][[2]], double(5), msg = "VectorXd::Zero(5)")
- checkEquals( res[[1]][[3]], double(5), msg = "VectorXf::Zero(5)")
- checkEquals( res[[1]][[4]], integer(5), msg = "VectorXi::Zero(5)")
+ checkEquals( res[["vectors : VectorX<T>"]][["Vec<complex>"]], complex(5), msg = "VectorXcd::Zero(5)")
+ checkEquals( res[["vectors : VectorX<T>"]][["Vec<double>"]], double(5), msg = "VectorXd::Zero(5)")
+ checkEquals( res[["vectors : VectorX<T>"]][["Vec<float>"]], double(5), msg = "VectorXf::Zero(5)")
+ checkEquals( res[["vectors : VectorX<T>"]][["Vec<int>"]], integer(5), msg = "VectorXi::Zero(5)")
- checkEquals( res[[2]][[1]], (1+0i) * diag(nr=3L), msg = "MatrixXcd::Identity(3,3)")
- checkEquals( res[[2]][[2]], diag(nr=3L), msg = "MatrixXd::Identity(3,3)")
- checkEquals( res[[2]][[3]], diag(nr=3L), msg = "MatrixXf::Identity(3,3)")
- checkEquals( res[[2]][[4]], matrix(as.integer((diag(nr=3L))),nr=3L), msg = "MatrixXi::Identity(3,3)")
+ checkEquals( res[["matrices : MatrixX<T>"]][["Mat<complex>"]], (1+0i) * diag(nr=3L), msg = "MatrixXcd::Identity(3,3)")
+ checkEquals( res[["matrices : MatrixX<T>"]][["Mat<double>"]], diag(nr=3L), msg = "MatrixXd::Identity(3,3)")
+ checkEquals( res[["matrices : MatrixX<T>"]][["Mat<float>"]], diag(nr=3L), msg = "MatrixXf::Identity(3,3)")
+ checkEquals( res[["matrices : MatrixX<T>"]][["Mat<int>"]], matrix(as.integer((diag(nr=3L))),nr=3L), msg = "MatrixXi::Identity(3,3)")
-## checkEquals( res[[3]][[1]], matrix(complex(5), nr=1L), msg = "RowVectorXcd::Zero(5)" )
-## checkEquals( res[[3]][[1]], matrix(numeric(5), nr=1L), msg = "RowVectorXd::Zero(5)" )
-## checkEquals( res[[3]][[2]], matrix(numeric(5), nr=1L), msg = "RowVectorXf::Zero(5)" )
-## checkEquals( res[[3]][[3]], matrix(integer(5), nr=1L), msg = "RowVectorXi::Zero(5)" )
+# checkEquals( res[["rows : RowVectorX<T>"]][["Row<complex>"]], matrix(complex(5), nr=1L), msg = "RowVectorXcd::Zero(5)" )
+# checkEquals( res[["rows : RowVectorX<T>"]][["Row<double>"]], matrix(numeric(5), nr=1L), msg = "RowVectorXd::Zero(5)" )
+# checkEquals( res[["rows : RowVectorX<T>"]][["Row<float>"]], matrix(numeric(5), nr=1L), msg = "RowVectorXf::Zero(5)" )
+# checkEquals( res[["rows : RowVectorX<T>"]][["Row<int>"]], matrix(integer(5), nr=1L), msg = "RowVectorXi::Zero(5)" )
- checkEquals( res[[3]][[1]], as.matrix(complex(5)), msg = "MatrixXcd::Zero(5, 1)")
- checkEquals( res[[3]][[2]], as.matrix(numeric(5)), msg = "MatrixXd::Zero(5, 1)")
- checkEquals( res[[3]][[3]], as.matrix(numeric(5)), msg = "MatrixXf::Zero(5, 1)")
- checkEquals( res[[3]][[4]], as.matrix(integer(5)), msg = "MatrixXi::Zero(5, 1)")
+ checkEquals( res[["columns : MatrixX<T>"]][["Col<complex>"]], as.matrix(complex(5)), msg = "MatrixXcd::Zero(5, 1)")
+ checkEquals( res[["columns : MatrixX<T>"]][["Col<double>"]], as.matrix(numeric(5)), msg = "MatrixXd::Zero(5, 1)")
+ checkEquals( res[["columns : MatrixX<T>"]][["Col<float>"]], as.matrix(numeric(5)), msg = "MatrixXf::Zero(5, 1)")
+ checkEquals( res[["columns : MatrixX<T>"]][["Col<int>"]], as.matrix(integer(5)), msg = "MatrixXi::Zero(5, 1)")
- checkEquals( res[[4]][[1]], matrix(complex(9L), nc=3L), msg = "ArrayXXcd::Zero(3,3)")
- checkEquals( res[[4]][[2]], matrix(numeric(9L), nc=3L), msg = "ArrayXXd::Zero(3,3)")
- checkEquals( res[[4]][[3]], matrix(numeric(9L), nc=3L), msg = "ArrayXXf::Zero(3,3)")
- checkEquals( res[[4]][[4]], matrix(integer(9L), nc=3L), msg = "ArrayXXi::Zero(3,3)")
+ checkEquals( res[["arrays2d : ArrayXX<T>"]][["Arr2<complex>"]], matrix(complex(9L), nc=3L), msg = "ArrayXXcd::Zero(3,3)")
+ checkEquals( res[["arrays2d : ArrayXX<T>"]][["Arr2<double>"]], matrix(numeric(9L), nc=3L), msg = "ArrayXXd::Zero(3,3)")
+ checkEquals( res[["arrays2d : ArrayXX<T>"]][["Arr2<float>"]], matrix(numeric(9L), nc=3L), msg = "ArrayXXf::Zero(3,3)")
+ checkEquals( res[["arrays2d : ArrayXX<T>"]][["Arr2<int>"]], matrix(integer(9L), nc=3L), msg = "ArrayXXi::Zero(3,3)")
- checkEquals( res[[5]][[1]], complex(5), msg = "ArrayXcd::Zero(5)")
- checkEquals( res[[5]][[2]], double(5), msg = "ArrayXd::Zero(5)")
- checkEquals( res[[5]][[3]], double(5), msg = "ArrayXf::Zero(5)")
- checkEquals( res[[5]][[4]], integer(5), msg = "ArrayXi::Zero(5)")
+ checkEquals( res[["arrays1d : ArrayX<T>"]][["Arr1<complex>"]], complex(5), msg = "ArrayXcd::Zero(5)")
+ checkEquals( res[["arrays1d : ArrayX<T>"]][["Arr1<double>"]], double(5), msg = "ArrayXd::Zero(5)")
+ checkEquals( res[["arrays1d : ArrayX<T>"]][["Arr1<float>"]], double(5), msg = "ArrayXf::Zero(5)")
+ checkEquals( res[["arrays1d : ArrayX<T>"]][["Arr1<int>"]], integer(5), msg = "ArrayXi::Zero(5)")
oneTen <- seq(1, 10, length.out=6L)
- checkEquals( res[[6]][[1]], oneTen, msg = "Op_seq")
- checkEquals( res[[6]][[2]], log(oneTen), msg = "Op_log")
- checkEquals( res[[6]][[3]], exp(oneTen), msg = "Op_exp")
- checkEquals( res[[6]][[4]], sqrt(oneTen), msg = "Op_sqrt")
- checkEquals( res[[6]][[5]], cos(oneTen), msg = "Op_cos")
+ checkEquals( res[["operations : ArrayXd"]][["Op_seq"]], oneTen, msg = "Op_seq")
+ checkEquals( res[["operations : ArrayXd"]][["Op_log"]], log(oneTen), msg = "Op_log")
+ checkEquals( res[["operations : ArrayXd"]][["Op_exp"]], exp(oneTen), msg = "Op_exp")
+ checkEquals( res[["operations : ArrayXd"]][["Op_sqrt"]], sqrt(oneTen), msg = "Op_sqrt")
+ checkEquals( res[["operations : ArrayXd"]][["Op_cos"]], cos(oneTen), msg = "Op_cos")
}
More information about the Rcpp-commits
mailing list