[Rcpp-commits] r2412 - in pkg/Irwls: . man src tests
noreply at r-forge.r-project.org
noreply at r-forge.r-project.org
Sat Nov 6 23:02:50 CET 2010
Author: jmc
Date: 2010-11-06 23:02:50 +0100 (Sat, 06 Nov 2010)
New Revision: 2412
Added:
pkg/Irwls/man/mIrwls-module.Rd
Modified:
pkg/Irwls/NAMESPACE
pkg/Irwls/src/irwls.cpp
pkg/Irwls/tests/trivial.R
Log:
use sugar in fit(), add minimal documentation
Modified: pkg/Irwls/NAMESPACE
===================================================================
--- pkg/Irwls/NAMESPACE 2010-11-06 21:56:06 UTC (rev 2411)
+++ pkg/Irwls/NAMESPACE 2010-11-06 22:02:50 UTC (rev 2412)
@@ -1,3 +1,4 @@
useDynLib(Irwls)
exportPattern("^[[:alpha:]]+")
+importFrom(Rcpp, Module)
importClassesFrom( Rcpp, "C++Object", "C++Class", "Module" )
Added: pkg/Irwls/man/mIrwls-module.Rd
===================================================================
--- pkg/Irwls/man/mIrwls-module.Rd (rev 0)
+++ pkg/Irwls/man/mIrwls-module.Rd 2010-11-06 22:02:50 UTC (rev 2412)
@@ -0,0 +1,19 @@
+\name{mIrwls}
+\alias{mIrwls}
+\title{
+ Rcpp module: mIrwls
+}
+\description{
+ Rcpp module for iterative weighted least-squares
+}
+\details{
+ The module contains the following items:
+
+
+
+ classes: \describe{
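+ \item{cppIrwls}{ class constructed from a numeric model matrix \code{x} and response vector \code{y}; its \code{fit(w)} method returns the weighted least-squares coefficients, where \code{w} must have the same length as \code{y} and contain no missing or negative values. }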
+ }
+}
+
+\keyword{datasets}
Modified: pkg/Irwls/src/irwls.cpp
===================================================================
--- pkg/Irwls/src/irwls.cpp 2010-11-06 21:56:06 UTC (rev 2411)
+++ pkg/Irwls/src/irwls.cpp 2010-11-06 22:02:50 UTC (rev 2412)
@@ -5,47 +5,52 @@
class Irwls{
public:
- NumericMatrix x, xw;
- NumericVector y, yw, wrt;
-
Irwls(SEXP xr, SEXP yr) {
x = NumericMatrix(xr);
y = NumericVector(yr);
- xw = NumericMatrix(x.nrow(), x.ncol());
+ xw = NumericVector(x.nrow()*x.ncol());
yw = NumericVector(y.size());
wrt = NumericVector(y.size());
}
- SEXP fit(SEXP wR) {
+ NumericVector fit(SEXP wR) {
compute_and_apply_weights(wR);
return do_fit();
}
+
+private:
+ NumericMatrix x;
+ NumericVector y, yw, wrt, xw;
+
- SEXP do_fit() {
- int n = xw.nrow(), p = xw.ncol();
+ NumericVector do_fit() {
+ int n = x.nrow(), p = x.ncol();
arma::mat X(xw.begin(), n, p, false);
arma::colvec y(yw.begin(), n, false);
arma::vec coefa = arma::solve(X, y);
- return wrap( coefa);
+ return wrap(coefa);
}
void compute_and_apply_weights(SEXP wR) {
- int n = yw.size(), p = xw.ncol(), np = n*p;
+ int n = x.nrow(), p = x.ncol();
+ wrt = sqrt(check_weights(wR));
+ yw = y * wrt;
+ xw = x * rep(wrt, p);
+ }
+
+ NumericVector check_weights(SEXP wR) {
+ //BEGIN_RCPP
NumericVector w(wR);
- NumericIterator ir = wrt.begin(), iw = w.begin(),
- ix = x.begin(), ixw = xw.begin(),
- iy = y.begin(), iyw = yw.begin();
- for (int i = 0; i < n; i++) {
- ir[i] = sqrt(iw[i]);
- iyw[i] = iy[i] * ir[i];
- }
- for (int ij = 0; ij < np; )
- for(int i = 0; i < n; i++) {
- ixw[ij] = ix[ij] * ir[i];
- ij++;
- }
+ if(w.size() != y.size())
+ throw std::invalid_argument("Weight vector wrong length");
+ if(as<bool>(any( is_na(w) )))
+ throw std::domain_error("Missing values not allowed in weights");
+ if(as<bool>(any(w < -0.0)))
+ throw std::domain_error("Negative weights found");
+ return w;
+ //END_RCPP
}
-
+
};
RCPP_MODULE(mIrwls) {
@@ -53,13 +58,6 @@
class_<Irwls>( "cppIrwls" )
.constructor(init_2<NumericMatrix,NumericVector>())
-
- .field("x", &Irwls::x)
- .field("y", &Irwls::y)
- .field("xw", &Irwls::xw)
- .field("yw", &Irwls::yw)
- .field("wrt", &Irwls::wrt)
-
.method("fit", &Irwls::fit)
;
Modified: pkg/Irwls/tests/trivial.R
===================================================================
--- pkg/Irwls/tests/trivial.R 2010-11-06 21:56:06 UTC (rev 2411)
+++ pkg/Irwls/tests/trivial.R 2010-11-06 22:02:50 UTC (rev 2412)
@@ -12,3 +12,12 @@
coef = irxly$fit(wt)
coef2 = lm(y ~ x + I(x^2), weights = wt)$coef
stopifnot(all.equal(as.vector(coef), as.vector(coef2)))
+wt[1] <- -1
+msg <- tryCatch(irxly$fit(wt), error = function(e)e)
+stopifnot(is(msg,"error"), grepl("negative", msg$message, ignore.case = TRUE))
+wt <- wt[-1]
+msg <- tryCatch(irxly$fit(wt), error = function(e)e)
+stopifnot(is(msg,"error"), grepl("length", msg$message, ignore.case = TRUE))
+wt <- c(NA, wt)
+msg <- tryCatch(irxly$fit(wt), error = function(e)e)
+stopifnot(is(msg,"error"), grepl("missing", msg$message, ignore.case = TRUE))
More information about the Rcpp-commits
mailing list