[Rcpp-commits] r2412 - in pkg/Irwls: . man src tests

noreply at r-forge.r-project.org noreply at r-forge.r-project.org
Sat Nov 6 23:02:50 CET 2010


Author: jmc
Date: 2010-11-06 23:02:50 +0100 (Sat, 06 Nov 2010)
New Revision: 2412

Added:
   pkg/Irwls/man/mIrwls-module.Rd
Modified:
   pkg/Irwls/NAMESPACE
   pkg/Irwls/src/irwls.cpp
   pkg/Irwls/tests/trivial.R
Log:
use sugar in fit(), add minimal documentation

Modified: pkg/Irwls/NAMESPACE
===================================================================
--- pkg/Irwls/NAMESPACE	2010-11-06 21:56:06 UTC (rev 2411)
+++ pkg/Irwls/NAMESPACE	2010-11-06 22:02:50 UTC (rev 2412)
@@ -1,3 +1,4 @@
 useDynLib(Irwls)
 exportPattern("^[[:alpha:]]+")
+importFrom(Rcpp, Module)
 importClassesFrom( Rcpp, "C++Object", "C++Class", "Module" )

Added: pkg/Irwls/man/mIrwls-module.Rd
===================================================================
--- pkg/Irwls/man/mIrwls-module.Rd	                        (rev 0)
+++ pkg/Irwls/man/mIrwls-module.Rd	2010-11-06 22:02:50 UTC (rev 2412)
@@ -0,0 +1,19 @@
+\name{mIrwls}
+\alias{mIrwls}
+\title{
+	Rcpp module: mIrwls
+}
+\description{
+	Rcpp module for iterative weighted least-squares
+}
+\details{
+	The module contains the following items: 
+	
+	 
+	
+	classes: \describe{
+        \item{cppIrwls}{ ~~ description of class cppIrwls ~~ }
+		}
+}
+
+\keyword{datasets}

Modified: pkg/Irwls/src/irwls.cpp
===================================================================
--- pkg/Irwls/src/irwls.cpp	2010-11-06 21:56:06 UTC (rev 2411)
+++ pkg/Irwls/src/irwls.cpp	2010-11-06 22:02:50 UTC (rev 2412)
@@ -5,47 +5,52 @@
 
 class Irwls{
 public:
-    NumericMatrix x, xw;
-    NumericVector y, yw, wrt;
-
     Irwls(SEXP xr, SEXP yr) {
 	x = NumericMatrix(xr);
 	y = NumericVector(yr);
-	xw = NumericMatrix(x.nrow(), x.ncol());
+	xw = NumericVector(x.nrow()*x.ncol());
 	yw = NumericVector(y.size());
 	wrt = NumericVector(y.size());
     }
 
-    SEXP fit(SEXP wR) {
+    NumericVector fit(SEXP wR) {
 	compute_and_apply_weights(wR);
 	return do_fit();
     }
+
+private:
+    NumericMatrix x;
+    NumericVector y, yw, wrt, xw;
+
 	
-    SEXP do_fit() {
-	int n = xw.nrow(), p = xw.ncol();
+    NumericVector do_fit() {
+	int n = x.nrow(), p = x.ncol();
 	arma::mat X(xw.begin(), n, p, false);
 	arma::colvec y(yw.begin(), n, false);
 	arma::vec coefa = arma::solve(X, y);
-	return wrap( coefa);
+	return wrap(coefa);
     }
 
     void compute_and_apply_weights(SEXP wR) {
-	int n = yw.size(), p = xw.ncol(), np = n*p;
+	int n = x.nrow(), p = x.ncol();
+	wrt = sqrt(check_weights(wR));
+	yw = y * wrt;
+	xw = x * rep(wrt, p);
+    }
+
+    NumericVector check_weights(SEXP wR) {
+	//BEGIN_RCPP
 	NumericVector w(wR);
-	NumericIterator ir = wrt.begin(), iw = w.begin(),
-	    ix = x.begin(), ixw = xw.begin(),
-	    iy = y.begin(), iyw = yw.begin();
-	for (int i = 0; i < n; i++) {
-	    ir[i] = sqrt(iw[i]);
-	    iyw[i] = iy[i] * ir[i];
-	}
-	for (int ij = 0; ij < np; )
-	    for(int i = 0; i < n; i++) {
-		ixw[ij] = ix[ij] * ir[i];
-		ij++;
-	    }
+	if(w.size() != y.size())
+	    throw std::invalid_argument("Weight vector wrong length");
+	if(as<bool>(any( is_na(w) )))
+	    throw std::domain_error("Missing values not allowed in weights");
+	if(as<bool>(any(w < -0.0)))
+	    throw std::domain_error("Negative weights found");
+	return w;
+	//END_RCPP
     }
-    
+
 };
 
 RCPP_MODULE(mIrwls) {
@@ -53,13 +58,6 @@
 class_<Irwls>( "cppIrwls" )
 
     .constructor(init_2<NumericMatrix,NumericVector>())
-
-    .field("x", &Irwls::x)
-    .field("y", &Irwls::y)
-    .field("xw", &Irwls::xw)
-    .field("yw", &Irwls::yw)
-    .field("wrt", &Irwls::wrt)
-
     .method("fit", &Irwls::fit)
     ;
 

Modified: pkg/Irwls/tests/trivial.R
===================================================================
--- pkg/Irwls/tests/trivial.R	2010-11-06 21:56:06 UTC (rev 2411)
+++ pkg/Irwls/tests/trivial.R	2010-11-06 22:02:50 UTC (rev 2412)
@@ -12,3 +12,12 @@
 coef = irxly$fit(wt)
 coef2 = lm(y ~ x + I(x^2), weights = wt)$coef
 stopifnot(all.equal(as.vector(coef), as.vector(coef2)))
+wt[1] <- -1
+msg <- tryCatch(irxly$fit(wt), error = function(e)e)
+stopifnot(is(msg,"error"), grepl("negative", msg$message, ignore.case = TRUE))
+wt <- wt[-1]
+msg <- tryCatch(irxly$fit(wt), error = function(e)e)
+stopifnot(is(msg,"error"), grepl("length", msg$message, ignore.case = TRUE))
+wt <- c(NA, wt)
+msg <- tryCatch(irxly$fit(wt), error = function(e)e)
+stopifnot(is(msg,"error"), grepl("missing", msg$message, ignore.case = TRUE))



More information about the Rcpp-commits mailing list