[Depmix-commits] r322 - in trunk: . R

noreply at r-forge.r-project.org noreply at r-forge.r-project.org
Wed Jan 20 14:42:51 CET 2010


Author: maarten
Date: 2010-01-20 14:42:51 +0100 (Wed, 20 Jan 2010)
New Revision: 322

Modified:
   trunk/NAMESPACE
   trunk/R/allGenerics.R
   trunk/R/responseMVN.R
Log:
fixed bug (#787) in MVNresponse

Modified: trunk/NAMESPACE
===================================================================
--- trunk/NAMESPACE	2010-01-20 12:31:56 UTC (rev 321)
+++ trunk/NAMESPACE	2010-01-20 13:42:51 UTC (rev 322)
@@ -45,7 +45,8 @@
 	depmix,
 	mix,
 	posterior,
-	GLMresponse,
+	GLMresponse,
+	MVNresponse,
 	transInit,
 	setpars,
 	getpars,

Modified: trunk/R/allGenerics.R
===================================================================
--- trunk/R/allGenerics.R	2010-01-20 12:31:56 UTC (rev 321)
+++ trunk/R/allGenerics.R	2010-01-20 13:42:51 UTC (rev 322)
@@ -15,7 +15,9 @@
 		respstart=NULL,trstart=NULL,instart=NULL,ntimes=NULL, ...) standardGeneric("depmix"))
 
 setGeneric("GLMresponse", function(formula, data = NULL, family = gaussian(), pstart =
-                 NULL, fixed = NULL, prob=TRUE, ...) standardGeneric("GLMresponse"))
+                 NULL, fixed = NULL, prob=TRUE, ...) standardGeneric("GLMresponse"))
+                 
+setGeneric("MVNresponse", function(formula,data,pstart=NULL,fixed=NULL,...) standardGeneric("MVNresponse"))
 
 setGeneric("transInit", function(formula, nstates, data = NULL, family = multinomial(),
                  pstart = NULL, fixed = NULL, prob=TRUE, ...) standardGeneric("transInit"))

Modified: trunk/R/responseMVN.R
===================================================================
--- trunk/R/responseMVN.R	2010-01-20 12:31:56 UTC (rev 321)
+++ trunk/R/responseMVN.R	2010-01-20 13:42:51 UTC (rev 322)
@@ -16,7 +16,7 @@
 	}
 )
 
-dm_dmvnorm <- function(y,mean,sigma,log=FALSE,logdet,invSigma) {
+dm_dmvnorm <- function(x,mean,sigma,log=FALSE,logdet,invSigma) {
   # taken from mvtnorm package
   # allows passing of logdet (sigma) and invsigma to save 
   # computation when called repeated times with same sigma 
@@ -25,33 +25,39 @@
     }
     if (missing(mean)) {
         mean <- rep(0, length = ncol(x))
-    }
-    if (missing(sigma)) {
-        sigma <- diag(ncol(x))
-    }
-    if (NCOL(x) != NCOL(sigma)) {
-        stop("x and sigma have non-conforming size")
-    }
-    if (NROW(sigma) != NCOL(sigma)) {
-        stop("sigma must be a square matrix")
-    }
-    if (length(mean) != NROW(sigma)) {
-        stop("mean and sigma have non-conforming size")
-    }
-    if(missing(invSigma)) {
-      distval <- mahalanobis(x, center = mean, cov = sigma) 
-    } else {
-      if (NCOL(x) != NCOL(invSigma)) {
-        stop("x and invSigma have non-conforming size")
-      }
-      if (NROW(invSigma) != NCOL(invSigma)) {
-          stop("invSigma must be a square matrix")
-      }
-      if (length(mean) != NROW(invSigma)) {
-          stop("mean and invSigma have non-conforming size")
-      }
-      distval <- mahalanobis(x, center = mean, cov = invSigma, inverted=TRUE)
-    } 
+    }
+    if(missing(invSigma)) {
+    	if (missing(sigma)) {
+        	sigma <- diag(ncol(x))
+    	}
+    	invSigma <- solve(sigma)
+    }
+	# check consistency
+	
+	if (NCOL(x) != NCOL(invSigma)) {
+	    stop("x and sigma have non-conforming size")
+	}
+	if (NROW(invSigma) != NCOL(invSigma)) {
+	    stop("sigma must be a square matrix")
+	}
+	if (NCOL(invSigma) != NCOL(mean)) {
+		stop("mean and sigma have non-conforming size")
+	}
+	if(NROW(mean) == NROW(x)) {
+		# varying means
+		
+		# from "mahalanobis":    
+		x <- as.matrix(x) - mean
+    	distval <- rowSums((x %*% invSigma) * x)
+    	#names(retval) <- rownames(x)
+   	 	#retval
+	} else {
+		# constant mean
+		if (length(mean) != NROW(invSigma)) {
+		    stop("mean and sigma have non-conforming size")
+		}
+		distval <- mahalanobis(x, center = mean, cov = invSigma, inverted=TRUE)
+	}	
     if(missing(logdet)) logdet <- sum(log(eigen(sigma, symmetric = TRUE, only.values = TRUE)$values))
     logretval <- -(ncol(x) * log(2 * pi) + logdet + distval)/2
     if (log) {
@@ -72,6 +78,13 @@
 	function(object,log=FALSE,...) {
 		dm_dmvnorm(x=object@y,mean=predict(object),sigma=object@parameters$Sigma,log=log,...)
 	}
+)
+
+
+setMethod("predict","MVNresponse",
+	function(object) {
+		object@x%*%object@parameters$coefficients
+	}
 )
 
 setMethod("simulate",signature(object="MVNresponse"),
@@ -88,37 +101,39 @@
     #if(nsim > 1) response <- array(response,dim=c(nt,ncol(response),nsim))
     return(response)
   }
-)
-
-MVNresponse <- function(formula,data,pstart=NULL,fixed=NULL,...) {
-	call <- match.call()
-	mf <- match.call(expand.dots = FALSE)
-	m <- match(c("formula", "data"), names(mf), 0)
-	mf <- mf[c(1, m)]
-	mf$drop.unused.levels <- TRUE
-	mf[[1]] <- as.name("model.frame")
-	mf <- eval(mf, parent.frame())
-	x <- model.matrix(attr(mf, "terms"),mf)
-	y <- model.response(mf)
-	if(!is.matrix(y)) y <- matrix(y,ncol=1)
-	parameters <- list()
-	parameters$coefficients <- matrix(0.0,ncol=ncol(y),nrow=ncol(x))
-	parameters$Sigma <- diag(ncol(y))
-	npar <- length(unlist(parameters))
-	if(is.null(fixed)) fixed <- as.logical(rep(0,npar))
-	if(!is.null(pstart)) {
-		if(length(pstart)!=npar) stop("length of 'pstart' must be",npar)
-		parameters$coefficients[1,] <- pstart[1:ncol(parameters$coefficients)]
-		pstart <- matrix(pstart,ncol(x),byrow=TRUE)
-		if(ncol(x)>1) parameters$coefficients[2:ncol(x),] <- pstart[2:ncol(x),]
-		if(length(unlist(parameters))>length(parameters$coefficients)) {
-			tmp <- as.numeric(pstart[(length(parameters$coefficients)+1):length(pstart)])
-			if(length(tmp) == ncol(parameters$Sigma)) parameters$Sigma <- diag(tmp) else parameters$Sigma <- matrix(tmp,ncol=ncol(y),nrow=ncol(y))
+)
+setMethod("MVNresponse",
+	signature(formula="formula"),
+	function(formula,data,pstart=NULL,fixed=NULL,...) {
+		call <- match.call()
+		mf <- match.call(expand.dots = FALSE)
+		m <- match(c("formula", "data"), names(mf), 0)
+		mf <- mf[c(1, m)]
+		mf$drop.unused.levels <- TRUE
+		mf[[1]] <- as.name("model.frame")
+		mf <- eval(mf, parent.frame())
+		x <- model.matrix(attr(mf, "terms"),mf)
+		y <- model.response(mf)
+		if(!is.matrix(y)) y <- matrix(y,ncol=1)
+		parameters <- list()
+		parameters$coefficients <- matrix(0.0,ncol=ncol(y),nrow=ncol(x))
+		parameters$Sigma <- diag(ncol(y))
+		npar <- length(unlist(parameters))
+		if(is.null(fixed)) fixed <- as.logical(rep(0,npar))
+		if(!is.null(pstart)) {
+			if(length(pstart)!=npar) stop("length of 'pstart' must be",npar)
+			parameters$coefficients[1,] <- pstart[1:ncol(parameters$coefficients)]
+			pstart <- matrix(pstart,ncol(x),byrow=TRUE)
+			if(ncol(x)>1) parameters$coefficients[2:ncol(x),] <- pstart[2:ncol(x),]
+			if(length(unlist(parameters))>length(parameters$coefficients)) {
+				tmp <- as.numeric(pstart[(length(parameters$coefficients)+1):length(pstart)])
+				if(length(tmp) == ncol(parameters$Sigma)) parameters$Sigma <- diag(tmp) else parameters$Sigma <- matrix(tmp,ncol=ncol(y),nrow=ncol(y))
+			}
 		}
+		mod <- new("MVNresponse",formula=formula,parameters=parameters,fixed=fixed,x=x,y=y,npar=npar)
+		mod
 	}
-	mod <- new("MVNresponse",formula=formula,parameters=parameters,fixed=fixed,x=x,y=y,npar=npar)
-	mod
-}
+)
 
 setMethod("show","MVNresponse",
 	function(object) {
@@ -167,4 +182,4 @@
 		)
 		return(pars)
 	}
-)
+)



More information about the depmix-commits mailing list